mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add TFSpeech2Text (#15113)
* Add wrapper classes * convert inner layers to tf * Add TF Encoder and Decoder layers * TFSpeech2Text models * Loadable model * TF model with same outputs as PT model * test skeleton * correct tests and run the fixup * correct attention expansion * TFSpeech2Text pask_key_values with TF format
This commit is contained in:
parent
6a5472a8e1
commit
8406fa6dd5
@ -227,7 +227,7 @@ Flax), PyTorch, and/or TensorFlow.
|
||||
| SEW | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| SEW-D | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Speech Encoder decoder | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Speech2Text | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||
| Speech2Text | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| Speech2Text2 | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Splinter | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
|
@ -202,6 +202,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its
|
||||
|
||||
[[autodoc]] TFAutoModelForVision2Seq
|
||||
|
||||
## TFAutoModelForSpeechSeq2Seq
|
||||
|
||||
[[autodoc]] TFAutoModelForSpeechSeq2Seq
|
||||
|
||||
## FlaxAutoModel
|
||||
|
||||
[[autodoc]] FlaxAutoModel
|
||||
|
@ -144,3 +144,13 @@ See the [model hub](https://huggingface.co/models?filter=speech_to_text) to look
|
||||
|
||||
[[autodoc]] Speech2TextForConditionalGeneration
|
||||
- forward
|
||||
|
||||
## TFSpeech2TextModel
|
||||
|
||||
[[autodoc]] TFSpeech2TextModel
|
||||
- call
|
||||
|
||||
## TFSpeech2TextForConditionalGeneration
|
||||
|
||||
[[autodoc]] TFSpeech2TextForConditionalGeneration
|
||||
- call
|
||||
|
@ -1621,6 +1621,7 @@ if is_tf_available():
|
||||
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
@ -1635,6 +1636,7 @@ if is_tf_available():
|
||||
"TFAutoModelForQuestionAnswering",
|
||||
"TFAutoModelForSeq2SeqLM",
|
||||
"TFAutoModelForSequenceClassification",
|
||||
"TFAutoModelForSpeechSeq2Seq",
|
||||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
@ -1946,6 +1948,14 @@ if is_tf_available():
|
||||
"TFRoFormerPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.speech_to_text"].extend(
|
||||
[
|
||||
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSpeech2TextForConditionalGeneration",
|
||||
"TFSpeech2TextModel",
|
||||
"TFSpeech2TextPreTrainedModel",
|
||||
]
|
||||
)
|
||||
_import_structure["models.t5"].extend(
|
||||
[
|
||||
"TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -3588,6 +3598,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
@ -3602,6 +3613,7 @@ if TYPE_CHECKING:
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
@ -3850,6 +3862,12 @@ if TYPE_CHECKING:
|
||||
TFRoFormerModel,
|
||||
TFRoFormerPreTrainedModel,
|
||||
)
|
||||
from .models.speech_to_text import (
|
||||
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSpeech2TextForConditionalGeneration,
|
||||
TFSpeech2TextModel,
|
||||
TFSpeech2TextPreTrainedModel,
|
||||
)
|
||||
from .models.t5 import (
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFT5EncoderModel,
|
||||
|
@ -394,9 +394,12 @@ class TFGenerationMixin:
|
||||
|
||||
Parameters:
|
||||
|
||||
input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
|
||||
The sequence used as a prompt for the generation. If `None` the method initializes it with
|
||||
`bos_token_id` and a batch size of 1.
|
||||
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, `(batch_size, sequence_length,
|
||||
feature_dim)` or `(batch_size, num_channels, height, width)`, *optional*):
|
||||
The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
|
||||
method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
|
||||
should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
|
||||
`input_ids`, `input_values`, `input_features`, or `pixel_values`.
|
||||
max_length (`int`, *optional*, defaults to 20):
|
||||
The maximum length of the sequence to be generated.
|
||||
min_length (`int`, *optional*, defaults to 10):
|
||||
@ -657,11 +660,12 @@ class TFGenerationMixin:
|
||||
), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
|
||||
|
||||
# create attention mask if necessary
|
||||
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
|
||||
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
|
||||
elif attention_mask is None:
|
||||
attention_mask = tf.ones_like(input_ids)
|
||||
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
|
||||
if accepts_attention_mask:
|
||||
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
|
||||
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
|
||||
elif attention_mask is None:
|
||||
attention_mask = tf.ones(shape_list(input_ids)[:2], dtype=tf.int32)
|
||||
|
||||
if pad_token_id is None and eos_token_id is not None:
|
||||
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
|
||||
@ -697,16 +701,12 @@ class TFGenerationMixin:
|
||||
encoder = self.get_encoder()
|
||||
|
||||
encoder_kwargs = {
|
||||
"attention_mask": attention_mask,
|
||||
"output_attentions": output_attentions,
|
||||
"output_hidden_states": output_hidden_states,
|
||||
"return_dict": return_dict_in_generate,
|
||||
}
|
||||
|
||||
# vision models don't use `attention_mask`.
|
||||
signature = dict(inspect.signature(encoder.call).parameters)
|
||||
if "attention_mask" not in signature:
|
||||
encoder_kwargs.pop("attention_mask")
|
||||
if accepts_attention_mask:
|
||||
encoder_kwargs["attention_mask"] = attention_mask
|
||||
|
||||
encoder_outputs = encoder(input_ids, **encoder_kwargs)
|
||||
if return_dict_in_generate:
|
||||
@ -715,23 +715,15 @@ class TFGenerationMixin:
|
||||
if output_hidden_states:
|
||||
model_kwargs["encoder_hidden_states"] = encoder_outputs.hidden_states
|
||||
|
||||
# The condition `len(shape_list(input_ids)) == 2` is to make this block treats only text inputs.
|
||||
# (vision inputs might occur when the model is an encoder-decoder model)
|
||||
# Expand input ids if num_beams > 1 or num_return_sequences > 1
|
||||
if len(shape_list(input_ids)) == 2 and (num_return_sequences > 1 or num_beams > 1):
|
||||
input_ids_len = shape_list(input_ids)[-1]
|
||||
input_ids = tf.broadcast_to(
|
||||
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||
)
|
||||
attention_mask = tf.broadcast_to(
|
||||
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
|
||||
)
|
||||
input_ids = tf.reshape(
|
||||
input_ids, (effective_batch_size * num_beams, input_ids_len)
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
attention_mask = tf.reshape(
|
||||
attention_mask, (effective_batch_size * num_beams, input_ids_len)
|
||||
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
|
||||
expanded_batch_idxs = tf.reshape(
|
||||
tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
|
||||
shape=(-1,),
|
||||
)
|
||||
# prepares text-based inputs
|
||||
if len(shape_list(input_ids)) == 2:
|
||||
input_ids = tf.gather(input_ids, expanded_batch_idxs, axis=0)
|
||||
if accepts_attention_mask:
|
||||
attention_mask = tf.gather(attention_mask, expanded_batch_idxs, axis=0)
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
|
||||
@ -749,11 +741,6 @@ class TFGenerationMixin:
|
||||
batch_size == encoder_outputs[0].shape[0]
|
||||
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
|
||||
|
||||
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
|
||||
expanded_batch_idxs = tf.reshape(
|
||||
tf.repeat(tf.expand_dims(tf.range(batch_size), -1), repeats=num_beams * effective_batch_mult, axis=1),
|
||||
shape=(-1,),
|
||||
)
|
||||
# expand encoder_outputs
|
||||
encoder_outputs = (tf.gather(encoder_outputs[0], expanded_batch_idxs, axis=0),)
|
||||
else:
|
||||
@ -851,7 +838,8 @@ class TFGenerationMixin:
|
||||
unfinished_sents = tf.ones_like(input_ids[:, 0])
|
||||
sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
|
||||
|
||||
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
|
||||
# defined for encoder-decoder models, None for decoder-only models
|
||||
past = encoder_outputs
|
||||
|
||||
# init attention / hidden states / scores tuples
|
||||
scores = () if (return_dict_in_generate and kwargs["output_scores"]) else None
|
||||
@ -871,7 +859,11 @@ class TFGenerationMixin:
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
|
||||
input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
@ -1132,7 +1124,11 @@ class TFGenerationMixin:
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
|
||||
input_ids,
|
||||
past=past,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
outputs = self(
|
||||
**model_inputs,
|
||||
|
@ -35,6 +35,7 @@ class TransposeType(ExplicitEnum):
|
||||
|
||||
NO = "no"
|
||||
SIMPLE = "simple"
|
||||
CONV1D = "conv1d"
|
||||
CONV2D = "conv2d"
|
||||
|
||||
|
||||
@ -68,8 +69,9 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
|
||||
|
||||
# When should we transpose the weights
|
||||
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
|
||||
# A simple heuristic to detect conv layer using weight array shape
|
||||
transpose = TransposeType.CONV2D
|
||||
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
|
||||
transpose = TransposeType.CONV1D
|
||||
elif bool(
|
||||
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
|
||||
or "emb_projs" in tf_name
|
||||
@ -194,7 +196,6 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
# authorized missing keys don't have to be loaded
|
||||
if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
|
||||
continue
|
||||
|
||||
raise AttributeError(f"{name} not found in PyTorch model")
|
||||
|
||||
array = pt_state_dict[name].numpy()
|
||||
@ -204,6 +205,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
||||
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
array = numpy.transpose(array, axes=(2, 3, 1, 0))
|
||||
elif transpose is TransposeType.CONV1D:
|
||||
# Conv1D weight:
|
||||
# PT: (num_out_channel, num_in_channel, kernel)
|
||||
# -> TF: (kernel, num_in_channel, num_out_channel)
|
||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
@ -355,7 +361,6 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
all_tf_weights = set(list(tf_weights_map.keys()))
|
||||
loaded_pt_weights_data_ptr = {}
|
||||
missing_keys_pt = []
|
||||
|
||||
for pt_weight_name, pt_weight in current_pt_params_dict.items():
|
||||
# Handle PyTorch shared weight ()not duplicated in TF 2.0
|
||||
if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
|
||||
@ -377,6 +382,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
|
||||
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
|
||||
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
|
||||
array = numpy.transpose(array, axes=(3, 2, 0, 1))
|
||||
elif transpose is TransposeType.CONV1D:
|
||||
# Conv1D weight:
|
||||
# TF: (kernel, num_in_channel, num_out_channel)
|
||||
# -> PT: (num_out_channel, num_in_channel, kernel)
|
||||
array = numpy.transpose(array, axes=(2, 1, 0))
|
||||
elif transpose is TransposeType.SIMPLE:
|
||||
array = numpy.transpose(array)
|
||||
|
||||
|
@ -87,6 +87,7 @@ if is_tf_available():
|
||||
"TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
|
||||
"TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
|
||||
"TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
|
||||
"TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
|
||||
"TF_MODEL_FOR_VISION_2_SEQ_MAPPING",
|
||||
@ -101,6 +102,7 @@ if is_tf_available():
|
||||
"TFAutoModelForQuestionAnswering",
|
||||
"TFAutoModelForSeq2SeqLM",
|
||||
"TFAutoModelForSequenceClassification",
|
||||
"TFAutoModelForSpeechSeq2Seq",
|
||||
"TFAutoModelForTableQuestionAnswering",
|
||||
"TFAutoModelForTokenClassification",
|
||||
"TFAutoModelForVision2Seq",
|
||||
@ -201,6 +203,7 @@ if TYPE_CHECKING:
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
|
||||
@ -215,6 +218,7 @@ if TYPE_CHECKING:
|
||||
TFAutoModelForQuestionAnswering,
|
||||
TFAutoModelForSeq2SeqLM,
|
||||
TFAutoModelForSequenceClassification,
|
||||
TFAutoModelForSpeechSeq2Seq,
|
||||
TFAutoModelForTableQuestionAnswering,
|
||||
TFAutoModelForTokenClassification,
|
||||
TFAutoModelForVision2Seq,
|
||||
|
@ -801,7 +801,7 @@ class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
||||
|
||||
|
||||
AutoModelForSpeechSeq2Seq = auto_class_update(
|
||||
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeing"
|
||||
AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
|
||||
)
|
||||
|
||||
|
||||
|
@ -29,6 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
TF_MODEL_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Base model mapping
|
||||
("speech_to_text", "TFSpeech2TextModel"),
|
||||
("clip", "TFCLIPModel"),
|
||||
("deberta-v2", "TFDebertaV2Model"),
|
||||
("deberta", "TFDebertaModel"),
|
||||
@ -103,6 +104,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
|
||||
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model with LM heads mapping
|
||||
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
|
||||
("rembert", "TFRemBertForMaskedLM"),
|
||||
("roformer", "TFRoFormerForMaskedLM"),
|
||||
("convbert", "TFConvBertForMaskedLM"),
|
||||
@ -204,6 +206,12 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
|
||||
]
|
||||
)
|
||||
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
|
||||
[
|
||||
# Model for Sequence Classification mapping
|
||||
@ -340,6 +348,9 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
|
||||
)
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
|
||||
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
|
||||
)
|
||||
@ -468,6 +479,15 @@ TFAutoModelForNextSentencePrediction = auto_class_update(
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
|
||||
_model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
|
||||
|
||||
|
||||
TFAutoModelForSpeechSeq2Seq = auto_class_update(
|
||||
TFAutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
|
||||
)
|
||||
|
||||
|
||||
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
|
@ -147,7 +147,11 @@ class TFBartAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
@ -296,11 +300,11 @@ class TFBartEncoderLayer(tf.keras.layers.Layer):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(encoder_attention_heads,)*
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
@ -372,17 +376,17 @@ class TFBartDecoderLayer(tf.keras.layers.Layer):
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (`tf.Tensor`):
|
||||
cross attention input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
@ -150,7 +150,11 @@ class TFBlenderbotAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
|
@ -149,7 +149,11 @@ class TFBlenderbotSmallAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
@ -299,11 +303,11 @@ class TFBlenderbotSmallEncoderLayer(tf.keras.layers.Layer):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(encoder_attention_heads,)*
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
@ -376,17 +380,17 @@ class TFBlenderbotSmallDecoderLayer(tf.keras.layers.Layer):
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (`tf.Tensor`):
|
||||
cross attention input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
@ -736,7 +736,11 @@ class TFHubertAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
|
@ -189,7 +189,11 @@ class TFMarianAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
@ -339,11 +343,11 @@ class TFMarianEncoderLayer(tf.keras.layers.Layer):
|
||||
def call(self, hidden_states: tf.Tensor, attention_mask: tf.Tensor, layer_head_mask: tf.Tensor, training=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(encoder_attention_heads,)*
|
||||
`(encoder_attention_heads,)`
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states, self_attn_weights, _ = self.self_attn(
|
||||
@ -416,17 +420,17 @@ class TFMarianDecoderLayer(tf.keras.layers.Layer):
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
hidden_states (`tf.Tensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
attention_mask (`tf.Tensor`): attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
encoder_hidden_states (`tf.Tensor`):
|
||||
cross attention input to the layer of shape *(seq_len, batch, embed_dim)*
|
||||
cross attention input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
|
||||
*(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
cross_attn_layer_head_mask (`tf.Tensor`): mask for heads of the cross-attention module.
|
||||
*(decoder_attention_heads,)*
|
||||
`(decoder_attention_heads,)`
|
||||
past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
@ -149,7 +149,11 @@ class TFMBartAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
|
@ -190,7 +190,11 @@ class TFPegasusAttention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
|
@ -17,7 +17,13 @@
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...file_utils import _LazyModule, is_sentencepiece_available, is_speech_available, is_torch_available
|
||||
from ...file_utils import (
|
||||
_LazyModule,
|
||||
is_sentencepiece_available,
|
||||
is_speech_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {
|
||||
@ -36,6 +42,14 @@ if is_speech_available():
|
||||
if is_sentencepiece_available():
|
||||
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
|
||||
|
||||
if is_tf_available():
|
||||
_import_structure["modeling_tf_speech_to_text"] = [
|
||||
"TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TFSpeech2TextForConditionalGeneration",
|
||||
"TFSpeech2TextModel",
|
||||
"TFSpeech2TextPreTrainedModel",
|
||||
]
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["modeling_speech_to_text"] = [
|
||||
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
@ -57,6 +71,14 @@ if TYPE_CHECKING:
|
||||
if is_sentencepiece_available():
|
||||
from .processing_speech_to_text import Speech2TextProcessor
|
||||
|
||||
if is_tf_available():
|
||||
from .modeling_tf_speech_to_text import (
|
||||
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TFSpeech2TextForConditionalGeneration,
|
||||
TFSpeech2TextModel,
|
||||
TFSpeech2TextPreTrainedModel,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_speech_to_text import (
|
||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
|
1615
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
Executable file
1615
src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -765,7 +765,11 @@ class TFWav2Vec2Attention(tf.keras.layers.Layer):
|
||||
self.num_heads = num_heads
|
||||
self.dropout = tf.keras.layers.Dropout(dropout)
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
if (self.head_dim * num_heads) != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
|
||||
f" and `num_heads`: {num_heads})."
|
||||
)
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
|
@ -198,6 +198,9 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = None
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None
|
||||
|
||||
|
||||
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = None
|
||||
|
||||
|
||||
@ -276,6 +279,13 @@ class TFAutoModelForSequenceClassification(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForSpeechSeq2Seq(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFAutoModelForTableQuestionAnswering(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
@ -1678,6 +1688,30 @@ class TFRoFormerPreTrainedModel(metaclass=DummyObject):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class TFSpeech2TextForConditionalGeneration(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSpeech2TextModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
class TFSpeech2TextPreTrainedModel(metaclass=DummyObject):
|
||||
_backends = ["tf"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["tf"])
|
||||
|
||||
|
||||
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
|
@ -1478,6 +1478,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
@ -1529,6 +1531,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
|
@ -57,6 +57,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
TFAutoModel,
|
||||
@ -140,6 +141,7 @@ class TFModelTesterMixin:
|
||||
*get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
|
||||
*get_values(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING),
|
||||
]:
|
||||
inputs_dict["labels"] = tf.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||
@ -358,7 +360,6 @@ class TFModelTesterMixin:
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
|
||||
)
|
||||
@ -374,6 +375,8 @@ class TFModelTesterMixin:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "pixel_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "input_features":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
@ -416,6 +419,8 @@ class TFModelTesterMixin:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "pixel_values":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
elif name == "input_features":
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
||||
else:
|
||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||
|
||||
@ -443,7 +448,24 @@ class TFModelTesterMixin:
|
||||
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
if self.is_encoder_decoder:
|
||||
if model_class.__name__ in ["TFSpeech2TextModel", "TFSpeech2TextForConditionalGeneration"]:
|
||||
inputs = {
|
||||
"decoder_input_ids": tf.keras.Input(
|
||||
batch_shape=(2, max_input),
|
||||
name="decoder_input_ids",
|
||||
dtype="int32",
|
||||
),
|
||||
"input_features": tf.keras.Input(
|
||||
batch_shape=(
|
||||
2,
|
||||
max_input,
|
||||
self.model_tester.input_feat_per_channel * self.model_tester.input_channels,
|
||||
),
|
||||
name="input_features",
|
||||
dtype="float32",
|
||||
),
|
||||
}
|
||||
elif self.is_encoder_decoder:
|
||||
inputs = {
|
||||
"decoder_input_ids": tf.keras.Input(
|
||||
batch_shape=(2, max_input),
|
||||
@ -511,10 +533,7 @@ class TFModelTesterMixin:
|
||||
outputs_dict = model(inputs)
|
||||
|
||||
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
|
||||
input_ids = inputs_keywords.pop("input_ids", None)
|
||||
if input_ids is None:
|
||||
input_ids = inputs_keywords.pop("pixel_values", None)
|
||||
outputs_keywords = model(input_ids, **inputs_keywords)
|
||||
outputs_keywords = model(**inputs_keywords)
|
||||
output_dict = outputs_dict[0].numpy()
|
||||
output_keywords = outputs_keywords[0].numpy()
|
||||
|
||||
@ -699,23 +718,28 @@ class TFModelTesterMixin:
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
list_lm_models = (
|
||||
text_in_text_out_models = (
|
||||
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
|
||||
)
|
||||
speech_in_text_out_models = get_values(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
|
||||
if model_class in list_lm_models:
|
||||
if model_class in text_in_text_out_models:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
elif model_class in speech_in_text_out_models:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
@ -922,13 +946,13 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined mobel needs input_ids
|
||||
# if bos token id is not defined model needs input_ids
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
elif model_class.__name__ not in ["TFSpeech2TextForConditionalGeneration"]:
|
||||
# Models with non-text inputs won't work here; num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(do_sample=True, max_length=5))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
@ -952,6 +976,8 @@ class TFModelTesterMixin:
|
||||
def test_lm_head_model_no_beam_search_generate_dict_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get("input_ids", None)
|
||||
if input_ids is None:
|
||||
input_ids = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@ -988,7 +1014,7 @@ class TFModelTesterMixin:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2))
|
||||
else:
|
||||
# num_return_sequences = 1
|
||||
@ -1023,6 +1049,8 @@ class TFModelTesterMixin:
|
||||
def test_lm_head_model_beam_search_generate_dict_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get("input_ids", None)
|
||||
if input_ids is None:
|
||||
input_ids = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@ -1072,10 +1100,11 @@ class TFModelTesterMixin:
|
||||
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values"
|
||||
input_ids = prepared_for_class.pop(input_name)
|
||||
possible_input_names = {"input_ids", "pixel_values", "input_features"}
|
||||
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
|
||||
model_input = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(input_ids, **prepared_for_class)[0]
|
||||
loss = model(model_input, **prepared_for_class)[0]
|
||||
self.assertEqual(loss.shape, [loss_size])
|
||||
|
||||
# Test that model correctly compute the loss with a dict
|
||||
|
605
tests/test_modeling_tf_speech_to_text.py
Normal file
605
tests/test_modeling_tf_speech_to_text.py
Normal file
@ -0,0 +1,605 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Testing suite for the TensorFlow Speech2Text model. """
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import Speech2TextConfig
|
||||
from transformers.file_utils import cached_property, is_tf_available
|
||||
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import Speech2TextProcessor, TFSpeech2TextForConditionalGeneration, TFSpeech2TextModel
|
||||
|
||||
|
||||
def prepare_speech_to_text_inputs_dict(
|
||||
config,
|
||||
input_features,
|
||||
decoder_input_ids,
|
||||
attention_mask=None,
|
||||
decoder_attention_mask=None,
|
||||
head_mask=None,
|
||||
decoder_head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
):
|
||||
if attention_mask is None:
|
||||
attention_mask = tf.math.not_equal(input_features, 0)
|
||||
if decoder_attention_mask is None:
|
||||
decoder_attention_mask = tf.math.not_equal(decoder_input_ids, config.pad_token_id)
|
||||
if head_mask is None:
|
||||
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
|
||||
if decoder_head_mask is None:
|
||||
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
if cross_attn_head_mask is None:
|
||||
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
|
||||
return {
|
||||
"input_features": input_features,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": attention_mask,
|
||||
"head_mask": head_mask,
|
||||
"decoder_head_mask": decoder_head_mask,
|
||||
"cross_attn_head_mask": cross_attn_head_mask,
|
||||
}
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFSpeech2TextModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_labels=False,
|
||||
vocab_size=99,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=4,
|
||||
num_conv_layers=2,
|
||||
conv_kernel_sizes=(5, 5),
|
||||
conv_channels=32,
|
||||
input_feat_per_channel=24,
|
||||
input_channels=1,
|
||||
hidden_act="relu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_source_positions=20,
|
||||
max_target_positions=20,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.use_labels = use_labels
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.conv_kernel_sizes = conv_kernel_sizes
|
||||
self.conv_channels = conv_channels
|
||||
self.input_feat_per_channel = input_feat_per_channel
|
||||
self.input_channels = input_channels
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.max_source_positions = max_source_positions
|
||||
self.max_target_positions = max_target_positions
|
||||
self.eos_token_id = eos_token_id
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_features = floats_tensor(
|
||||
[self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size
|
||||
)
|
||||
attention_mask = tf.ones([self.batch_size, self.seq_length], dtype=tf.int64)
|
||||
decoder_input_ids = tf.math.maximum(ids_tensor([self.batch_size, self.seq_length], self.vocab_size), 2)
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = prepare_speech_to_text_inputs_dict(
|
||||
config,
|
||||
input_features=input_features,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
return Speech2TextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
d_model=self.hidden_size,
|
||||
encoder_layers=self.num_hidden_layers,
|
||||
decoder_layers=self.num_hidden_layers,
|
||||
encoder_attention_heads=self.num_attention_heads,
|
||||
decoder_attention_heads=self.num_attention_heads,
|
||||
encoder_ffn_dim=self.intermediate_size,
|
||||
decoder_ffn_dim=self.intermediate_size,
|
||||
num_conv_layers=self.num_conv_layers,
|
||||
conv_kernel_sizes=self.conv_kernel_sizes,
|
||||
conv_channels=self.conv_channels,
|
||||
input_feat_per_channel=self.input_feat_per_channel,
|
||||
input_channels=self.input_channels,
|
||||
dropout=self.hidden_dropout_prob,
|
||||
attention_dropout=self.attention_probs_dropout_prob,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
max_source_positions=self.max_source_positions,
|
||||
max_target_positions=self.max_target_positions,
|
||||
eos_token_id=self.eos_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
pad_token_id=self.pad_token_id,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config, inputs_dict = self.prepare_config_and_inputs()
|
||||
return config, inputs_dict
|
||||
|
||||
def get_subsampled_output_lengths(self, input_lengths):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
for _ in range(self.num_conv_layers):
|
||||
input_lengths = (input_lengths - 1) // 2 + 1
|
||||
|
||||
return input_lengths
|
||||
|
||||
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
|
||||
model = TFSpeech2TextModel(config=config).get_decoder()
|
||||
input_ids = inputs_dict["decoder_input_ids"]
|
||||
attention_mask = inputs_dict["decoder_attention_mask"]
|
||||
|
||||
# first forward pass
|
||||
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
|
||||
|
||||
_, (_, past_key_values) = outputs.to_tuple()
|
||||
|
||||
# create hypothetical multiple next token and extent to next_input_ids
|
||||
next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2)
|
||||
next_attn_mask = ids_tensor((self.batch_size, 3), 2, dtype=tf.int64)
|
||||
|
||||
# append to next input_ids and
|
||||
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
|
||||
next_attention_mask = tf.concat([attention_mask, next_attn_mask], axis=-1)
|
||||
|
||||
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
|
||||
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
|
||||
"last_hidden_state"
|
||||
]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
|
||||
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx]
|
||||
output_from_past_slice = output_from_past[:, :, random_slice_idx]
|
||||
|
||||
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
|
||||
|
||||
# test that outputs are equal for slice
|
||||
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, atol=1e-2)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFSpeech2TextModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (TFSpeech2TextModel, TFSpeech2TextForConditionalGeneration) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFSpeech2TextForConditionalGeneration,) if is_tf_available() else ()
|
||||
is_encoder_decoder = True
|
||||
test_pruning = False
|
||||
test_missing_keys = False
|
||||
test_onnx = False
|
||||
|
||||
input_name = "input_ids"
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFSpeech2TextModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
|
||||
self.maxDiff = 3000
|
||||
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
def test_decoder_model_past_with_large_inputs(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
||||
|
||||
# not implemented currently
|
||||
def test_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
# training is not supported yet
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
def test_training_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
def test_generate_fp16(self):
|
||||
pass
|
||||
|
||||
def test_hidden_states_output(self):
|
||||
def check_hidden_states_output(inputs_dict, config, model_class):
|
||||
model = model_class(config)
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
|
||||
|
||||
expected_num_layers = getattr(
|
||||
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
|
||||
)
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
|
||||
if hasattr(self.model_tester, "encoder_seq_length"):
|
||||
seq_length = self.model_tester.encoder_seq_length
|
||||
else:
|
||||
seq_length = self.model_tester.seq_length
|
||||
|
||||
subsampled_seq_length = model._get_feat_extract_output_lengths(seq_length)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[subsampled_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
if config.is_encoder_decoder:
|
||||
hidden_states = outputs.decoder_hidden_states
|
||||
|
||||
self.assertIsInstance(hidden_states, (list, tuple))
|
||||
self.assertEqual(len(hidden_states), expected_num_layers)
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]),
|
||||
[decoder_seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
# check that output_hidden_states also work using config
|
||||
del inputs_dict["output_hidden_states"]
|
||||
config.output_hidden_states = True
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def test_attention_outputs(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.return_dict = True
|
||||
|
||||
seq_len = getattr(self.model_tester, "seq_length", None)
|
||||
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
|
||||
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
|
||||
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
|
||||
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = False
|
||||
config.return_dict = True
|
||||
model = model_class(config)
|
||||
|
||||
subsampled_encoder_seq_length = model._get_feat_extract_output_lengths(encoder_seq_length)
|
||||
subsampled_encoder_key_length = model._get_feat_extract_output_lengths(encoder_key_length)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
# check that output_attentions also work using config
|
||||
del inputs_dict["output_attentions"]
|
||||
config.output_attentions = True
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
out_len = len(outputs)
|
||||
|
||||
correct_outlen = 5
|
||||
|
||||
# loss is at first position
|
||||
if "labels" in inputs_dict:
|
||||
correct_outlen += 1 # loss is added to beginning
|
||||
if "past_key_values" in outputs:
|
||||
correct_outlen += 1 # past_key_values have been returned
|
||||
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
# decoder attentions
|
||||
decoder_attentions = outputs.decoder_attentions
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
|
||||
)
|
||||
|
||||
# cross attentions
|
||||
cross_attentions = outputs.cross_attentions
|
||||
self.assertIsInstance(cross_attentions, (list, tuple))
|
||||
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(cross_attentions[0].shape[-3:]),
|
||||
[
|
||||
self.model_tester.num_attention_heads,
|
||||
decoder_seq_length,
|
||||
subsampled_encoder_key_length,
|
||||
],
|
||||
)
|
||||
|
||||
# Check attention is always last and order is fine
|
||||
inputs_dict["output_attentions"] = True
|
||||
inputs_dict["output_hidden_states"] = True
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
added_hidden_states = 2
|
||||
self.assertEqual(out_len + added_hidden_states, len(outputs))
|
||||
|
||||
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
|
||||
|
||||
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(self_attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
|
||||
)
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
# Overwritten method from parent; see `test_resize_embeddings_untied`
|
||||
pass
|
||||
|
||||
def test_resize_tokens_embeddings(self):
|
||||
# see `test_resize_embeddings_untied`
|
||||
pass
|
||||
|
||||
def test_resize_embeddings_untied(self):
|
||||
# TODO: copy test from PT. Not working at the moment because the test relies on `model.resize_token_embeddings`,
|
||||
# whose TF implementation assumes the use of `TFWrappedEmbeddings`. But with a `TFWrappedEmbeddings` we can't
|
||||
# load the weights from PT (also, it induces TF1 behavior, so we might want to rework how
|
||||
# `model.resize_token_embeddings` operates).
|
||||
pass
|
||||
|
||||
def test_generate_without_input_ids(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _get_encoder_outputs(
|
||||
model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
|
||||
):
|
||||
encoder = model.get_encoder()
|
||||
encoder_outputs = encoder(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
encoder_outputs["last_hidden_state"] = tf.repeat(encoder_outputs.last_hidden_state, num_interleave, axis=0)
|
||||
|
||||
input_ids = input_ids[:, :, 0]
|
||||
input_ids = tf.zeros_like(input_ids[:, :1], dtype=tf.int64) + model._get_decoder_start_token_id()
|
||||
attention_mask = None
|
||||
return encoder_outputs, input_ids, attention_mask
|
||||
|
||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
||||
batch_size, seq_length = input_ids.shape[:2]
|
||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
||||
num_sequences_in_output = batch_size * num_return_sequences
|
||||
gen_len = (
|
||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
||||
)
|
||||
|
||||
# scores
|
||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
||||
|
||||
# Attentions
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(
|
||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
# encoder
|
||||
self._check_encoder_hidden_states_for_generate(
|
||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
||||
)
|
||||
|
||||
# decoder
|
||||
self._check_hidden_states_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_hidden_states,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
# iterate over all generative models
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_features
|
||||
with self.assertRaises(AssertionError):
|
||||
model.generate(do_sample=True, max_length=5)
|
||||
# num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating multiple sequences when no beam search generation
|
||||
# is not allowed as it would always generate the same sequences
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_return_sequences=2))
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=True, bad_words_ids=bad_words_ids, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||
# `input_features`
|
||||
def test_lm_head_model_random_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_features = inputs_dict.get("input_features", None)
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
if config.bos_token_id is None:
|
||||
# if bos token id is not defined model needs input_ids, num_return_sequences = 1
|
||||
self._check_generated_ids(model.generate(input_features, do_sample=True, num_beams=2))
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
# generating more sequences than having beams leads is not possible
|
||||
model.generate(input_features, do_sample=False, num_return_sequences=3, num_beams=2)
|
||||
|
||||
# num_return_sequences > 1, sample
|
||||
self._check_generated_ids(
|
||||
model.generate(
|
||||
input_features,
|
||||
do_sample=True,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
)
|
||||
# num_return_sequences > 1, greedy
|
||||
self._check_generated_ids(
|
||||
model.generate(input_features, do_sample=False, num_beams=2, num_return_sequences=2)
|
||||
)
|
||||
|
||||
# check bad words tokens language generation
|
||||
# create list of 1-seq bad token and list of 2-seq of bad tokens
|
||||
bad_words_ids = [self._generate_random_bad_tokens(1, model), self._generate_random_bad_tokens(2, model)]
|
||||
output_tokens = model.generate(
|
||||
input_features, do_sample=False, bad_words_ids=bad_words_ids, num_beams=2, num_return_sequences=2
|
||||
)
|
||||
# only count generated tokens
|
||||
generated_ids = output_tokens[:, input_features.shape[-1] :]
|
||||
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
|
||||
|
||||
# overwritten from parent -- the input is `input_features`, not `input_ids`
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
signature = inspect.signature(model.call)
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = [
|
||||
"input_features",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
|
||||
@require_tf
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
@slow
|
||||
class TFSpeech2TextModelIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def default_processor(self):
|
||||
return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"]
|
||||
|
||||
return [x["array"] for x in speech_samples]
|
||||
|
||||
def test_generation_librispeech(self):
|
||||
model = TFSpeech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
processor = self.default_processor
|
||||
|
||||
input_speech = self._load_datasamples(1)
|
||||
|
||||
input_features = processor(input_speech, return_tensors="tf").input_features
|
||||
|
||||
generated_ids = model.generate(input_features)
|
||||
generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
|
||||
]
|
||||
self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)
|
||||
|
||||
def test_generation_librispeech_batched(self):
|
||||
model = TFSpeech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
|
||||
processor = self.default_processor
|
||||
|
||||
input_speech = self._load_datasamples(4)
|
||||
|
||||
inputs = processor(input_speech, return_tensors="tf", padding=True)
|
||||
generated_ids = model.generate(inputs.input_features, attention_mask=inputs.attention_mask)
|
||||
generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
EXPECTED_TRANSCRIPTIONS = [
|
||||
"mister quilter is the apostle of the middle classes and we are glad to welcome his gospel",
|
||||
"nor is mister cultar's manner less interesting than his matter",
|
||||
"he tells us that at this festive season of the year with christmas and roast beef looming before us similes drawn from eating and its results occur most readily to the mind",
|
||||
"he has grave doubts whether sir frederick leyton's work is really greek after all and can discover in it but little of rocky ithaca",
|
||||
]
|
||||
self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)
|
Loading…
Reference in New Issue
Block a user