mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 03:01:07 +06:00
Add TF whisper (#19378)
* simplify loop * add featur extractor * add model * start conversion * add dropout * initial commit of test files * copnversion for all models * update processor for correct padding * update feature extraction * update integration test logits match * fmnt: off for the logits * on the fly mel bank * small nit * update test * update tokenizer * nit feature extraction * update * update tokenizer test * adds logit processor and update tokenizer to get supress tokens * style * clean convert * revert to original modeling tf utils * Update * update * nit * clean convert file * update tests and nits * quality * slow generation test * ffn_dim to allow customization * update readme * add to toctreee * start fixing integration tests * update tests and code * fix feature extractor * fix config tests common * update code to fix tests * fix feature exctractor * nit feature extraction * update test for new feature extractor * style * add absrtact * large logits wioth custom decoder input ids * wraap around is otrch available * fix feature extractor * correct logits for whisper small.en * nit * fix encoder_attentino_mask * some fixes * remove unnecessary inputs * nits * add normalizer file * update etst tokenization * fix attention mask not defined * fix generate * remove uncoder attention mask useless * update test modeling whisper * update condfig to add second non supress tokens * nits on feature exrtactor * nit for test tokenizers * update etsts * update tests * update tokenization test * fixup * invalidated hf token. Clean convert openai to whisper * fix logit tests * fixup * Add model to README * Fix doc tests * clean merge * revert toc_tree changes * remove useless LogitProcessor * Update whisper .mdx * update config file doc * update configuration docstring * update test tokenization * update test tokenization * update tokenization whisper Added copied from where needed * update feature extraction * nit test name * style * quality * remove get suppress tokens and update non_speech tokens global variables * Update src/transformers/models/whisper/feature_extraction_whisper.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * clean modeling whisper and test Removed the attention mask arguments that are deprecated * fix large test * Add multilingual audio test, and translate test * style * fix larg multilingual test * nits * add copied from for attention layer * remove attention masks in doc * add english normalizer * Update docs/source/en/model_doc/whisper.mdx Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * update tokenization test * remove copied from in whisper attention : no bias in k_proj only * wrap around dependencies in english normalizer * style * correct import generation logits * for now, wrap feature extractor with torch * remove torch depencies for feature extraction and style * Update src/transformers/models/whisper/convert_openai_whisper_to_tfms.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update docs/source/en/model_doc/whisper.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * fixup * nit * update logitds * style * nit * nits and fix final tests * add `is_more_itertools_available` to utils * quality * add begin supress tokens, supress tokens to generate args and config * clean supressTokensLogitProcessor in generation logits * Nit naming * add supressTokensAtBegin * udpate tests, supress tokens to None or correct values * nit and style * update RAG to fit test and generate_logit * add copy pasted statment on english normalizer * add arguments to config_common_kwargs * Update src/transformers/generation_utils.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Update src/transformers/generation_logits_process.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * revert changes based on reviews * update doc and nits * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * more nits * last nits * update test configuration common * add BART name in decoder attention mask documentation * Update src/transformers/models/whisper/modeling_whisper.py Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> * style * nit * nit * add english.json file to git * nits on documentation * nit * nits * last styling * add main toctree file * remove sentence piece dependency * clean init file * fix tokenizer that has no dependencies on sentencepiece * update whisper init file, nit * remove english.json file * add get decoder prompt id * All weights loading * Remove hanging pdb * Fixup and tidy up * Use same copied from as PT model * Remove whitespace changes * Remove torch references * Tie embeddings * Remove logits processor input to generate * Update logit values * revert changes and add forced logit processor * nit * clean normalizer * remove protected * Add logit processors and update generation code & tests * Some tidy up * Update docstring * update * update based on review * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/whisper/configuration_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update to reflect changes on the PT model branch * Tidy up * Remove extra whitespace * Fix test - make input ids small enough we can append * Include upstream changes on main * PR comments - add batch tests, remove comments & defaults * Fix model output imports * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation_tf_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/models/whisper/test_modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update docstring example * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Remove changes to adjust_logits_during_generation function * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Tidy up imports that don't require TF * Update tests - skip and no more skip * Update tests/generation/test_generation_tf_logits_process.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/whisper/modeling_tf_whisper.py * Update src/transformers/models/whisper/modeling_tf_whisper.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> * Add training flags * Add (skipped) XLA generation tests * Add embedding correctness test * Add constant ids for generation tests * Make logits finding a bit tidier * Remove unused args * xla generation enabled * Don't skip XLA tests anymore * Fix tests - add position ids to expected signature and update rag generation * Undo method reorder * Remove added whitespace * Remove copy-paste gradient checkopint ref * Remove * Trigger CI - (issue with refs when pulling) Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: NielsRogge <niels.rogge1@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> Co-authored-by: Joao Gante <joao@huggingface.co>
This commit is contained in:
parent
af69360bf9
commit
e3f028f3af
@ -330,7 +330,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
|
||||
| Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Whisper | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| XGLM | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
|
@ -27,7 +27,7 @@ Tips:
|
||||
- The architecture follows a classic encoder-decoder architecture, which means that it relies on the [`~generation_utils.GenerationMixin.generate`] function for inference.
|
||||
- One can use [`WhisperProcessor`] to prepare audio for the model, and decode the predicted ID's back into text.
|
||||
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ).
|
||||
This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ). The Tensorflow version of this model was contributed by [amyeroberts](https://huggingface.co/amyeroberts).
|
||||
The original code can be found [here](https://github.com/openai/whisper).
|
||||
|
||||
|
||||
@ -66,3 +66,14 @@ The original code can be found [here](https://github.com/openai/whisper).
|
||||
|
||||
[[autodoc]] WhisperForConditionalGeneration
|
||||
- forward
|
||||
|
||||
|
||||
## TFWhisperModel
|
||||
|
||||
[[autodoc]] TFWhisperModel
|
||||
- call
|
||||
|
||||
## TFWhisperForConditionalGeneration
|
||||
|
||||
[[autodoc]] TFWhisperForConditionalGeneration
|
||||
- call
|
||||
|
@ -2754,6 +2754,14 @@ else:
|
||||
"TFWav2Vec2PreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.whisper"].extend(
|
||||
[
|
||||
"TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFWhisperForConditionalGeneration",
|
||||
"TFWhisperModel",
|
||||
"TFWhisperPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.xglm"].extend(
|
||||
[
|
||||
"TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -5303,6 +5311,12 @@ if TYPE_CHECKING:
|
||||
TFWav2Vec2Model,
|
||||
TFWav2Vec2PreTrainedModel,
|
||||
)
|
||||
from .models.whisper import (
|
||||
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFWhisperForConditionalGeneration,
|
||||
TFWhisperModel,
|
||||
TFWhisperPreTrainedModel,
|
||||
)
|
||||
from .models.xglm import (
|
||||
TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFXGLMForCausalLM,
|
||||
|
@ -504,3 +504,84 @@ class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
|
||||
axis=-1,
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
|
||||
r"""
|
||||
[`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
|
||||
generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not
|
||||
sampled at the begining of the generation.
|
||||
"""
|
||||
|
||||
def __init__(self, begin_suppress_tokens, begin_index):
|
||||
self.begin_suppress_tokens = list(begin_suppress_tokens)
|
||||
self.begin_index = begin_index
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
scores = tf.cond(
|
||||
tf.equal(cur_len, self.begin_index),
|
||||
lambda: tf.tensor_scatter_nd_update(
|
||||
scores,
|
||||
indices=[[i, token] for i in range(scores.shape[0]) for token in self.begin_suppress_tokens],
|
||||
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.begin_suppress_tokens))],
|
||||
),
|
||||
lambda: scores,
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
|
||||
r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
|
||||
are not sampled."""
|
||||
|
||||
def __init__(self, suppress_tokens):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
scores = tf.tensor_scatter_nd_update(
|
||||
scores,
|
||||
indices=[[i, token] for i in range(scores.shape[0]) for token in self.suppress_tokens],
|
||||
updates=[-float("inf") for _ in range(scores.shape[0] * len(self.suppress_tokens))],
|
||||
)
|
||||
return scores
|
||||
|
||||
|
||||
class TFForceTokensLogitsProcessor(TFLogitsProcessor):
|
||||
r"""This processor can be used to force a list of tokens. The processor will set their log probs to `0` and all
|
||||
other tokens to `-inf` so that they are sampled at their corresponding index."""
|
||||
|
||||
def __init__(self, force_token_map):
|
||||
force_token_map = dict(force_token_map)
|
||||
# Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
|
||||
# index of the array corresponds to the index of the token to be forced, for XLA compatibility.
|
||||
# Indexes without forced tokens will have an negative value.
|
||||
force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1
|
||||
for index, token in force_token_map.items():
|
||||
force_token_array[index] = token
|
||||
self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)
|
||||
|
||||
def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
|
||||
def _force_token(generation_idx):
|
||||
batch_size = scores.shape[0]
|
||||
current_token = self.force_token_array[generation_idx]
|
||||
|
||||
new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float("inf")
|
||||
indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
|
||||
updates = tf.zeros((batch_size,), dtype=scores.dtype)
|
||||
new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
|
||||
return new_scores
|
||||
|
||||
scores = tf.cond(
|
||||
tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
|
||||
# If the current length is geq than the length of force_token_array, the processor does nothing.
|
||||
lambda: tf.identity(scores),
|
||||
# Otherwise, it may force a certain token.
|
||||
lambda: tf.cond(
|
||||
tf.greater_equal(self.force_token_array[cur_len], 0),
|
||||
# Only valid (positive) tokens are forced
|
||||
lambda: _force_token(cur_len),
|
||||
# Otherwise, the processor does nothing.
|
||||
lambda: scores,
|
||||
),
|
||||
)
|
||||
return scores
|
||||
|
@ -26,11 +26,14 @@ from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
|
||||
from .generation_tf_logits_process import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
@ -401,6 +404,9 @@ class TFGenerationMixin:
|
||||
return_dict_in_generate=None,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
suppress_tokens: Optional[List[int]] = None,
|
||||
begin_suppress_tokens: Optional[List[int]] = None,
|
||||
forced_decoder_ids: Optional[List[int]] = None,
|
||||
**model_kwargs,
|
||||
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
|
||||
r"""
|
||||
@ -494,6 +500,14 @@ class TFGenerationMixin:
|
||||
the target language token.
|
||||
forced_eos_token_id (`int`, *optional*):
|
||||
The id of the token to force as the last generated token when `max_length` is reached.
|
||||
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
|
||||
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set
|
||||
their log probs to `-inf` so that they are not sampled.
|
||||
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
|
||||
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
|
||||
logit processor will set their log probs to `-inf` so that they are not sampled.
|
||||
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
|
||||
A list of tokens that will be forced as beginning tokens, before sampling.
|
||||
model_specific_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `forward` function of the model.
|
||||
|
||||
@ -609,6 +623,9 @@ class TFGenerationMixin:
|
||||
return_dict_in_generate=return_dict_in_generate,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
suppress_tokens=suppress_tokens,
|
||||
begin_suppress_tokens=begin_suppress_tokens,
|
||||
forced_decoder_ids=forced_decoder_ids,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@ -648,6 +665,12 @@ class TFGenerationMixin:
|
||||
forced_eos_token_id = (
|
||||
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
|
||||
)
|
||||
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
|
||||
begin_suppress_tokens = (
|
||||
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
|
||||
)
|
||||
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
|
||||
forced_decoder_ids = self.config.forced_decoder_ids
|
||||
|
||||
output_scores = output_scores if output_scores is not None else self.config.output_scores
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
@ -1368,6 +1391,9 @@ class TFGenerationMixin:
|
||||
return_dict_in_generate=None,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
suppress_tokens=None,
|
||||
begin_suppress_tokens=None,
|
||||
forced_decoder_ids=None,
|
||||
**model_kwargs,
|
||||
) -> Union[TFGreedySearchOutput, TFSampleOutput, TFBeamSearchOutput, TFBeamSampleOutput, tf.Tensor]:
|
||||
r"""
|
||||
@ -1461,6 +1487,15 @@ class TFGenerationMixin:
|
||||
the target language token.
|
||||
forced_eos_token_id (`int`, *optional*):
|
||||
The id of the token to force as the last generated token when `max_length` is reached.
|
||||
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
|
||||
A list of tokens that will be supressed at generation. The `SupressTokens` logit processor will set
|
||||
their log probs to `-inf` so that they are not sampled.
|
||||
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
|
||||
A list of tokens that will be supressed at the begining of the generation. The `SupressBeginTokens`
|
||||
logit processor will set their log probs to `-inf` so that they are not sampled.
|
||||
forced_decoder_ids (`List[int]`, *optional*, defaults to `model.config.forced_decoder_ids`):
|
||||
A list of tokens that will be forced as beginning tokens.
|
||||
|
||||
model_kwargs:
|
||||
Additional model specific kwargs will be forwarded to the `call` function of the model.
|
||||
|
||||
@ -1695,12 +1730,16 @@ class TFGenerationMixin:
|
||||
logits_processor = self._get_logits_processor(
|
||||
repetition_penalty=repetition_penalty,
|
||||
no_repeat_ngram_size=no_repeat_ngram_size,
|
||||
input_ids_seq_length=input_ids_seq_length,
|
||||
bad_words_ids=bad_words_ids,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
eos_token_id=eos_token_id,
|
||||
forced_bos_token_id=forced_bos_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
suppress_tokens=suppress_tokens,
|
||||
begin_suppress_tokens=begin_suppress_tokens,
|
||||
forced_decoder_ids=forced_decoder_ids,
|
||||
)
|
||||
|
||||
# 9. go into different generation modes
|
||||
@ -1994,7 +2033,7 @@ class TFGenerationMixin:
|
||||
def _initialize_past(past, num_padding_values, batch_axis):
|
||||
"""initialize past with zeros -- the structure depends on `batch_axis`"""
|
||||
if batch_axis == 0:
|
||||
padding_values = tf.scatter_nd(indices=[[2, 1]], updates=[num_padding_values], shape=(4, 2))
|
||||
padding_values = tf.constant([[0, 0], [0, 0], [0, num_padding_values], [0, 0]], dtype=tf.int32)
|
||||
new_past = ()
|
||||
for past_layer in past:
|
||||
new_past_layer = list(past_layer)
|
||||
@ -2099,12 +2138,16 @@ class TFGenerationMixin:
|
||||
self,
|
||||
repetition_penalty: float,
|
||||
no_repeat_ngram_size: int,
|
||||
input_ids_seq_length: int,
|
||||
bad_words_ids: List[List[int]],
|
||||
min_length: int,
|
||||
max_length: int,
|
||||
eos_token_id: int,
|
||||
forced_bos_token_id: int,
|
||||
forced_eos_token_id: int,
|
||||
suppress_tokens: Optional[List[int]] = None,
|
||||
begin_suppress_tokens: Optional[List[int]] = None,
|
||||
forced_decoder_ids: Optional[List[int]] = None,
|
||||
) -> TFLogitsProcessorList:
|
||||
"""
|
||||
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
|
||||
@ -2118,6 +2161,12 @@ class TFGenerationMixin:
|
||||
)
|
||||
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
|
||||
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
||||
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
|
||||
begin_suppress_tokens = (
|
||||
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
|
||||
)
|
||||
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
|
||||
forced_decoder_ids = self.config.forced_decoder_ids
|
||||
|
||||
# instantiate processors list
|
||||
if repetition_penalty is not None and repetition_penalty != 1.0:
|
||||
@ -2132,7 +2181,16 @@ class TFGenerationMixin:
|
||||
processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
|
||||
if forced_eos_token_id is not None:
|
||||
processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
|
||||
|
||||
if suppress_tokens is not None:
|
||||
processors.append(TFSuppressTokensLogitsProcessor(suppress_tokens))
|
||||
if begin_suppress_tokens is not None:
|
||||
begin_index = input_ids_seq_length
|
||||
begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
|
||||
if forced_decoder_ids is not None:
|
||||
begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
|
||||
processors.append(TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
|
||||
if forced_decoder_ids is not None:
|
||||
processors.append(TFForceTokensLogitsProcessor(forced_decoder_ids))
|
||||
return processors
|
||||
|
||||
def greedy_search(
|
||||
|
@ -80,6 +80,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
("vit", "TFViTModel"),
|
||||
("vit_mae", "TFViTMAEModel"),
|
||||
("wav2vec2", "TFWav2Vec2Model"),
|
||||
("whisper", "TFWhisperModel"),
|
||||
("xglm", "TFXGLMModel"),
|
||||
("xlm", "TFXLMModel"),
|
||||
("xlm-roberta", "TFXLMRobertaModel"),
|
||||
@ -145,6 +146,7 @@ TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
("t5", "TFT5ForConditionalGeneration"),
|
||||
("tapas", "TFTapasForMaskedLM"),
|
||||
("transfo-xl", "TFTransfoXLLMHeadModel"),
|
||||
("whisper", "TFWhisperForConditionalGeneration"),
|
||||
("xlm", "TFXLMWithLMHeadModel"),
|
||||
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
|
||||
("xlnet", "TFXLNetLMHeadModel"),
|
||||
@ -253,6 +255,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
|
||||
("whisper", "TFWhisperForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -1262,6 +1262,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
|
||||
eos_token_id=eos_token_id,
|
||||
forced_bos_token_id=None,
|
||||
forced_eos_token_id=None,
|
||||
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
|
||||
)
|
||||
model_kwargs["attention_mask"] = context_attention_mask
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
||||
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -41,6 +41,18 @@ else:
|
||||
"WhisperPreTrainedModel",
|
||||
]
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_tf_whisper"] = [
|
||||
"TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFWhisperForConditionalGeneration",
|
||||
"TFWhisperModel",
|
||||
"TFWhisperPreTrainedModel",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
|
||||
@ -61,6 +73,19 @@ if TYPE_CHECKING:
|
||||
WhisperPreTrainedModel,
|
||||
)
|
||||
|
||||
try:
|
||||
if not is_tf_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_tf_whisper import (
|
||||
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFWhisperForConditionalGeneration,
|
||||
TFWhisperModel,
|
||||
TFWhisperPreTrainedModel,
|
||||
)
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -218,7 +218,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
padding: Optional[str] = "max_length",
|
||||
max_length: Optional[int] = None,
|
||||
sampling_rate: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
@ -262,19 +261,6 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
The value that is used to fill the padding values / vectors.
|
||||
"""
|
||||
|
||||
if sampling_rate is not None:
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(
|
||||
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
|
||||
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
|
||||
f" {self.sampling_rate} and not {sampling_rate}."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
|
1401
src/transformers/models/whisper/modeling_tf_whisper.py
Normal file
1401
src/transformers/models/whisper/modeling_tf_whisper.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -2394,6 +2394,30 @@ class TFWav2Vec2PreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_WHISPER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFWhisperForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFWhisperModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFWhisperPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_XGLM_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -29,11 +29,14 @@ if is_tf_available():
|
||||
from transformers.generation_tf_logits_process import (
|
||||
TFForcedBOSTokenLogitsProcessor,
|
||||
TFForcedEOSTokenLogitsProcessor,
|
||||
TFForceTokensLogitsProcessor,
|
||||
TFLogitsProcessorList,
|
||||
TFMinLengthLogitsProcessor,
|
||||
TFNoBadWordsLogitsProcessor,
|
||||
TFNoRepeatNGramLogitsProcessor,
|
||||
TFRepetitionPenaltyLogitsProcessor,
|
||||
TFSuppressTokensAtBeginLogitsProcessor,
|
||||
TFSuppressTokensLogitsProcessor,
|
||||
TFTemperatureLogitsWarper,
|
||||
TFTopKLogitsWarper,
|
||||
TFTopPLogitsWarper,
|
||||
@ -331,6 +334,86 @@ class TFLogitsProcessorTest(unittest.TestCase):
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_at_begin_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
begin_suppress_tokens = [1, 2, 3]
|
||||
begin_index = 5
|
||||
|
||||
logits_processor = TFSuppressTokensAtBeginLogitsProcessor(
|
||||
begin_suppress_tokens=begin_suppress_tokens, begin_index=begin_index
|
||||
)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that no scores are suppressed if begin_index is not reached
|
||||
cur_len = 4
|
||||
input_ids = tf.convert_to_tensor([[11, 17, 15, 8], [14, 0, 19, 5], [13, 11, 18, 19], [11, 12, 16, 15]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
# Check that scores are suppressed if begin_index is reached
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[5, 5, 5, 0, 17], [18, 1, 9, 14, 17], [18, 6, 8, 15, 19], [8, 12, 17, 1, 2]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, begin_suppress_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_suppress_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
suppress_tokens = [1, 3, 5]
|
||||
keep_tokens = [i for i in range(vocab_size) if i not in suppress_tokens]
|
||||
|
||||
logits_processor = TFSuppressTokensLogitsProcessor(suppress_tokens=suppress_tokens)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# Check that suppress_tokens are suppressed and others are not
|
||||
cur_len = 5
|
||||
input_ids = tf.convert_to_tensor([[0, 10, 19, 6, 3], [17, 4, 8, 17, 2], [7, 1, 11, 6, 15], [5, 8, 13, 16, 0]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertTrue(tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, suppress_tokens, axis=1))))
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf(tf.gather(scores, keep_tokens, axis=1))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_force_tokens_logits_processor(self, use_xla):
|
||||
vocab_size = 20
|
||||
batch_size = 4
|
||||
|
||||
force_token_map = {1: 2, 3: 2}
|
||||
|
||||
logits_processor = TFForceTokensLogitsProcessor(force_token_map=force_token_map)
|
||||
if use_xla:
|
||||
logits_processor = tf.function(logits_processor, jit_compile=True)
|
||||
|
||||
# check that if the cur_len is contained in the force_token_map, the logits are the same
|
||||
# for all tokens except the one the force_token_map points to
|
||||
cur_len = 1
|
||||
input_ids = tf.convert_to_tensor([[11], [7], [5], [15]])
|
||||
ids_tensor((batch_size, cur_len), vocab_size=20)
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
tf.debugging.assert_near(tf.gather(scores, [force_token_map[cur_len]], axis=1), 0.0)
|
||||
|
||||
non_forced_inds = [i for i in range(vocab_size) if i != force_token_map[cur_len]]
|
||||
self.assertTrue(
|
||||
tf.math.reduce_all(tf.math.is_inf(tf.gather(scores, [non_forced_inds], axis=1))),
|
||||
)
|
||||
|
||||
# check that if the cur_len is not contained in the force_token_map, the logits are not modified
|
||||
cur_len = 2
|
||||
input_ids = tf.convert_to_tensor([[2, 19], [19, 15], [4, 9], [7, 6]])
|
||||
scores = self._get_uniform_logits(batch_size, vocab_size)
|
||||
scores = logits_processor(input_ids, scores, cur_len)
|
||||
self.assertFalse(tf.math.reduce_any(tf.math.is_inf((scores))))
|
||||
|
||||
@parameterized.expand([(False,), (True,)])
|
||||
def test_processor_list(self, use_xla):
|
||||
# TODO (Joao): reintroduce TFNoRepeatNGramLogitsProcessor when it gets compatible with XLA
|
||||
|
983
tests/models/whisper/test_modeling_tf_whisper.py
Normal file
983
tests/models/whisper/test_modeling_tf_whisper.py
Normal file
@ -0,0 +1,983 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow Whisper model. """
|
||||
|
||||
import inspect
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import WhisperConfig, WhisperFeatureExtractor, WhisperProcessor
|
||||
from transformers.testing_utils import is_tf_available, require_tf, require_tokenizers, slow
|
||||
from transformers.utils import cached_property
|
||||
from transformers.utils.import_utils import is_datasets_available
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_datasets_available():
|
||||
import datasets
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFWhisperForConditionalGeneration, TFWhisperModel, set_seed
|
||||
from transformers.models.whisper.modeling_tf_whisper import TFWhisperDecoder, TFWhisperEncoder
|
||||
|
||||
|
||||
def prepare_whisper_inputs_dict(
|
||||
config,
|
||||
input_features,
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = tf.where(decoder_input_ids != config.pad_token_id, 1, 0)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWhisperModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=60,
|
||||
is_training=True,
|
||||
use_labels=False,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
input_channels=1,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_source_positions=30,
|
||||
max_target_positions=60,
|
||||
bos_token_id=98,
|
||||
eos_token_id=98,
|
||||
pad_token_id=0,
|
||||
num_mel_bins=80,
|
||||
decoder_start_token_id=85,
|
||||
num_conv_layers=1,
|
||||
suppress_tokens=None,
|
||||
begin_suppress_tokens=None,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.input_channels = input_channels
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.max_source_positions = max_source_positions
|
||||
self.max_target_positions = max_target_positions
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.decoder_start_token_id = decoder_start_token_id
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.suppress_tokens = suppress_tokens
|
||||
self.begin_suppress_tokens = begin_suppress_tokens
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size)
|
||||
|
||||
decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_whisper_inputs_dict(
|
||||
config,
|
||||
attention_mask=None,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
return WhisperConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
input_channels=self.input_channels,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
max_source_positions=self.max_source_positions,
|
||||
max_target_positions=self.max_target_positions,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
decoder_ffn_dim=self.hidden_size,
|
||||
encoder_ffn_dim=self.hidden_size,
|
||||
decoder_start_token_id=self.decoder_start_token_id,
|
||||
suppress_tokens=self.suppress_tokens,
|
||||
begin_suppress_tokens=self.begin_suppress_tokens,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def get_subsampled_output_lengths(self, input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
for i in range(self.num_conv_layers):
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_model_forward(self, config, inputs_dict):
|
||||
model = TFWhisperModel(config=config)
|
||||
|
||||
input_features = inputs_dict["input_features"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
|
||||
# first forward pass
|
||||
last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
|
||||
self.parent.assertTrue(last_hidden_state.shape, (13, 7, 16))
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||
model = TFWhisperModel(config=config).get_decoder()
|
||||
# take a slice so we're shorter than the seqeuence length and can append later
|
||||
input_ids = inputs_dict["decoder_input_ids"][:, :-10]
|
||||
attention_mask = inputs_dict["decoder_attention_mask"][:, :-10]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
output, past_key_values = outputs.to_tuple()
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_token = ids_tensor((self.batch_size, 3), config.vocab_size)
|
||||
next_tokens = tf.where(next_token <= 2, 2, next_token)
|
||||
next_attn_mask = ids_tensor((self.batch_size, 3), 2)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = np.random.randint(0, output_from_past.shape[-1])
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx]
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
self.parent.assertTrue(np.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))
|
||||
|
||||
def check_encoder_decoder_model_standalone(self, config, inputs_dict):
|
||||
model = TFWhisperModel(config=config)
|
||||
outputs = model(**inputs_dict)
|
||||
|
||||
encoder_last_hidden_state = outputs.encoder_last_hidden_state
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
encoder = model.get_encoder()
|
||||
encoder.save_pretrained(tmpdirname)
|
||||
encoder = TFWhisperEncoder.from_pretrained(tmpdirname)
|
||||
|
||||
encoder_last_hidden_state_2 = encoder(inputs_dict["input_features"])[0]
|
||||
|
||||
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max() < 1e-3)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
decoder = model.get_decoder()
|
||||
decoder.save_pretrained(tmpdirname)
|
||||
decoder = TFWhisperDecoder.from_pretrained(tmpdirname)
|
||||
|
||||
last_hidden_state_2 = decoder(
|
||||
input_ids=inputs_dict["decoder_input_ids"],
|
||||
attention_mask=inputs_dict["decoder_attention_mask"],
|
||||
encoder_hidden_states=encoder_last_hidden_state,
|
||||
)[0]
|
||||
|
||||
self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max() < 1e-3)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWhisperModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFWhisperModel, TFWhisperForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFWhisperForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
fx_compatible = False
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_onnx = False
|
||||
|
||||
input_name = "input_features"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFWhisperModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=WhisperConfig)
|
||||
self.maxDiff = 3000
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_save_load_strict(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
model(model.dummy_inputs)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
|
||||
self.assertEqual(info["missing_keys"], [])
|
||||
|
||||
def test_model_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_forward(*config_and_inputs)
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
def _get_input_ids_and_config(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict[self.input_name]
|
||||
|
||||
# cut to half length & take max batch_size 3
|
||||
max_batch_size = 3
|
||||
input_ids = input_ids[:max_batch_size, :, :]
|
||||
|
||||
# generate max 3 tokens
|
||||
max_length = input_ids.shape[-1] + 3
|
||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
return config, input_ids, None, max_length
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Training is not yet supported")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_generate_with_head_masking(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("fp16 is not yet supported for TF models")
|
||||
def test_generate_fp16(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs()
|
||||
config.max_target_positions = 400
|
||||
input_features = input_dict["input_features"]
|
||||
model = TFWhisperForConditionalGeneration(config)
|
||||
model.generate(input_features)
|
||||
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = [
|
||||
"input_features",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["decoder_position_ids", "head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||
seq_length = self.model_tester.encoder_seq_length
|
||||
else:
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[subsampled_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
self.assertIsInstance(hidden_states, (list, tuple))
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[decoder_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", encoder_key_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
|
||||
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 5
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
decoder_seq_length,
|
||||
subsampled_encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
):
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(
|
||||
input_ids,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
|
||||
num_interleave, dim=0
|
||||
)
|
||||
input_ids = input_ids[:, :, 0]
|
||||
input_ids = tf.zeros_like(input_ids[:, :1], dtype=tf.int64) + tf.convert_to_tensor(
|
||||
[model._get_decoder_start_token_id()]
|
||||
)
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, mel, seq_length = input_ids.shape
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_features
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_features,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(
|
||||
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
|
||||
)
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_tokenizers
|
||||
class TFWhisperModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return WhisperProcessor.from_pretrained("openai/whisper-base")
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
@slow
|
||||
def test_tiny_logits_librispeech(self):
|
||||
set_seed(0)
|
||||
model = TFWhisperModel.from_pretrained("openai/whisper-tiny")
|
||||
input_speech = self._load_datasamples(1)
|
||||
feature_extractor = WhisperFeatureExtractor()
|
||||
input_features = feature_extractor(input_speech, return_tensors="tf").input_features
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=tf.convert_to_tensor([[50258, 50259, 50359]]),
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407,
|
||||
0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246,
|
||||
4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713,
|
||||
0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_GENERATION = tf.convert_to_tensor(
|
||||
[
|
||||
-1.4651, -2.6944, 2.7821, 2.3793, 4.0738, 0.0188, -3.3203, 1.9836,
|
||||
0.0520, 0.7095, 1.1063, 0.2952, -3.6786, -0.5249, 0.3105, 4.7691,
|
||||
1.1562, 1.3046, 0.5810, -0.3624, 1.7006, 1.3424, 0.9817, 2.1958,
|
||||
1.8775, -5.7046, -0.7679, 4.0113, 2.6848, 2.8609
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
head_logits = logits[0] @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
|
||||
self.assertTrue(np.allclose(head_logits[0, 0, :30], EXPECTED_GENERATION, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_small_en_logits_librispeech(self):
|
||||
set_seed(0)
|
||||
model = TFWhisperModel.from_pretrained("openai/whisper-small.en")
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
feaure_extractor = WhisperFeatureExtractor()
|
||||
input_features = feaure_extractor(input_speech, return_tensors="tf").input_features
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=tf.convert_to_tensor([[model.config.decoder_start_token_id]]),
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
-3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188,
|
||||
-8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935,
|
||||
-6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781,
|
||||
-10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509,
|
||||
-11.1146, -8.1918
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_large_logits_librispeech(self):
|
||||
set_seed(0)
|
||||
|
||||
model = TFWhisperModel.from_pretrained("openai/whisper-large")
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
processed_inputs = processor(audio=input_speech, text="This part of the speech", return_tensors="tf")
|
||||
input_features = processed_inputs.input_features
|
||||
labels = processed_inputs.labels
|
||||
|
||||
logits = model(
|
||||
input_features,
|
||||
decoder_input_ids=labels,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
logits = logits.last_hidden_state @ tf.transpose(model.model.decoder.embed_tokens.weights[0])
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472,
|
||||
1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357,
|
||||
1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376,
|
||||
1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4))
|
||||
|
||||
@slow
|
||||
def test_tiny_en_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
model.config.decoder_start_token_id = 50257
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
transcript = processor.tokenizer.batch_decode(generated_ids)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
"<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle"
|
||||
" classes, and we are glad to"
|
||||
)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
|
||||
" classes and we are glad"
|
||||
)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_xla_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
generated_ids = model.generate(input_features, num_beams=5)
|
||||
generated_ids_xla = xla_generate(input_features, num_beams=5)
|
||||
|
||||
transcript = processor.tokenizer.decode(generated_ids[0])
|
||||
transcript_xla = processor.tokenizer.decode(generated_ids_xla[0])
|
||||
|
||||
EXPECTED_TRANSCRIPT = (
|
||||
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
|
||||
" classes and we are glad"
|
||||
)
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
self.assertEqual(transcript_xla, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_large_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_large_generation_multilingual(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
|
||||
ds = load_dataset("common_voice", "ja", split="test", streaming=True)
|
||||
ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
|
||||
input_speech = next(iter(ds))["audio"]["array"]
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe")
|
||||
generated_ids = model.generate(
|
||||
input_features,
|
||||
do_sample=False,
|
||||
)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " Kimura san ni denwa wo kaite moraimashita"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate")
|
||||
generated_ids = model.generate(input_features, do_sample=False)
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||
|
||||
EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san"
|
||||
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_large_batched_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-large")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
generated_ids = model.generate(input_features)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
|
||||
[50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
|
||||
[50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
|
||||
[50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to',
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all,"
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_en_batched_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
generated_ids = model.generate(input_features)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
|
||||
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
|
||||
[50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
|
||||
[50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
|
||||
]
|
||||
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef looming",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
|
||||
@slow
|
||||
def test_tiny_en_batched_xla_generation(self):
|
||||
set_seed(0)
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="tf").input_features
|
||||
|
||||
xla_generate = tf.function(model.generate, jit_compile=True)
|
||||
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_ids_xla = xla_generate(input_features)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_LOGITS = tf.convert_to_tensor(
|
||||
[
|
||||
[50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284],
|
||||
[50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256],
|
||||
[50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236],
|
||||
[50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460]
|
||||
]
|
||||
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
|
||||
self.assertTrue(np.allclose(generated_ids_xla, EXPECTED_LOGITS))
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to",
|
||||
" Nor is Mr. Quilter's manner less interesting than his matter.",
|
||||
" He tells us that at this festive season of the year, with Christmas and roast beef looming",
|
||||
" He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can",
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
transcript_xla = processor.batch_decode(generated_ids_xla, skip_special_tokens=True)
|
||||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT)
|
||||
self.assertListEqual(transcript_xla, EXPECTED_TRANSCRIPT)
|
@ -736,6 +736,23 @@ class TFModelTesterMixin:
|
||||
dtype="float32",
|
||||
),
|
||||
}
|
||||
elif model_class.__name__ in ["TFWhisperModel", "TFWhisperForConditionalGeneration"]:
|
||||
inputs = {
|
||||
"decoder_input_ids": tf.keras.Input(
|
||||
batch_shape=(2, max_input),
|
||||
name="decoder_input_ids",
|
||||
dtype="int32",
|
||||
),
|
||||
"input_features": tf.keras.Input(
|
||||
batch_shape=(
|
||||
2,
|
||||
self.model_tester.num_mel_bins,
|
||||
self.model_tester.seq_length,
|
||||
),
|
||||
name="input_features",
|
||||
dtype="float32",
|
||||
),
|
||||
}
|
||||
elif self.is_encoder_decoder:
|
||||
inputs = {
|
||||
"decoder_input_ids": tf.keras.Input(
|
||||
@ -1223,8 +1240,17 @@ class TFModelTesterMixin:
|
||||
|
||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||
inputs_dict = copy.deepcopy(original_inputs_dict)
|
||||
new_vocab_input_ids = ids_tensor(inputs_dict["input_ids"].shape, new_tokens_size)
|
||||
ids_feat_name = None
|
||||
if "input_ids" in inputs_dict:
|
||||
ids_feat_name = "input_ids"
|
||||
elif "decoder_input_ids" in inputs_dict:
|
||||
ids_feat_name = "decoder_input_ids"
|
||||
else:
|
||||
assert False, "No input ids feature found in the inputs dict"
|
||||
|
||||
new_vocab_input_ids = ids_tensor(inputs_dict[ids_feat_name].shape, new_tokens_size)
|
||||
new_vocab_input_ids += old_total_size
|
||||
inputs_dict[ids_feat_name] = new_vocab_input_ids
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs_dict["input_ids"] = new_vocab_input_ids
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
|
@ -105,6 +105,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||
"TFRobertaForMultipleChoice", # TODO: fix
|
||||
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
|
||||
"TFWhisperEncoder", # Building part of bigger (tested) model.
|
||||
"TFWhisperDecoder", # Building part of bigger (tested) model.
|
||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
|
||||
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
|
||||
|
@ -97,4 +97,5 @@ src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
|
||||
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
|
||||
src/transformers/models/wavlm/modeling_wavlm.py
|
||||
src/transformers/models/whisper/modeling_whisper.py
|
||||
src/transformers/models/whisper/modeling_tf_whisper.py
|
||||
src/transformers/models/yolos/modeling_yolos.py
|
||||
|
Loading…
Reference in New Issue
Block a user