# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Speech2Text model. """

import copy
import inspect
import os
import tempfile
import unittest

from transformers.file_utils import cached_property
from transformers.testing_utils import (
    is_torch_available,
    require_sentencepiece,
    require_tokenizers,
    require_torch,
    require_torchaudio,
    slow,
    torch_device,
)

from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor


if is_torch_available():
    import torch

    from transformers import (
        Speech2TextConfig,
        Speech2TextForConditionalGeneration,
        Speech2TextModel,
        Speech2TextProcessor,
    )
    from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, Speech2TextEncoder


def prepare_speech_to_text_inputs_dict(
    config,
    input_features,
    decoder_input_ids,
    attention_mask=None,
    decoder_attention_mask=None,
    head_mask=None,
    decoder_head_mask=None,
    cross_attn_head_mask=None,
):
    """
    Build the keyword-argument dict used to call a Speech2Text model in these tests.

    Any mask left as ``None`` is filled with a default:

    - ``attention_mask``: ``input_features.ne(0)``
    - ``decoder_attention_mask``: ``decoder_input_ids.ne(config.pad_token_id)``
    - ``head_mask``: all ones, shaped ``(config.encoder_layers, config.encoder_attention_heads)``
    - ``decoder_head_mask`` / ``cross_attn_head_mask``: all ones, shaped
      ``(config.decoder_layers, config.decoder_attention_heads)``

    Returns:
        dict with the keys ``input_features``, ``decoder_input_ids``, ``attention_mask``,
        ``decoder_attention_mask``, ``head_mask``, ``decoder_head_mask`` and ``cross_attn_head_mask``.
    """
    if attention_mask is None:
        attention_mask = input_features.ne(0)
    if decoder_attention_mask is None:
        decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
    if head_mask is None:
        head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
    if decoder_head_mask is None:
        decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    if cross_attn_head_mask is None:
        cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
    return {
        "input_features": input_features,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
        # The decoder gets its own mask; the encoder mask must not be reused here.
        "decoder_attention_mask": decoder_attention_mask,
        "head_mask": head_mask,
        "decoder_head_mask": decoder_head_mask,
        "cross_attn_head_mask": cross_attn_head_mask,
    }


@require_torch
class Speech2TextModelTester:
    """
    Helper that holds the (small) model hyper-parameters used by the tests and builds
    matching ``Speech2TextConfig`` objects and random inputs.

    ``parent`` is the ``unittest.TestCase`` whose assertion methods are used by the checks.
    """

    def __init__(
        self,
        parent,
        batch_size=13,
        seq_length=7,
        is_training=True,
        use_labels=False,
        vocab_size=99,
        hidden_size=16,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=4,
        num_conv_layers=2,
        conv_kernel_sizes=(5, 5),
        conv_channels=32,
        input_feat_per_channel=24,
        input_channels=1,
        hidden_act="relu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=20,
        max_source_positions=20,
        max_target_positions=20,
        eos_token_id=2,
        pad_token_id=1,
        bos_token_id=0,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.num_conv_layers = num_conv_layers
        self.conv_kernel_sizes = conv_kernel_sizes
        self.conv_channels = conv_channels
        self.input_feat_per_channel = input_feat_per_channel
        self.input_channels = input_channels
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id

    def prepare_config_and_inputs(self):
        """Return ``(config, inputs_dict)`` built from the tester's hyper-parameters and random inputs."""
        input_features = floats_tensor(
            [self.batch_size, self.seq_length, self.input_feat_per_channel], self.vocab_size
        )
        attention_mask = torch.ones([self.batch_size, self.seq_length], dtype=torch.long, device=torch_device)
        # clamp(2) keeps the random decoder ids at or above 2, i.e. away from bos (0) and pad (1)
        decoder_input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(2)

        config = Speech2TextConfig(
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            encoder_layers=self.num_hidden_layers,
            decoder_layers=self.num_hidden_layers,
            encoder_attention_heads=self.num_attention_heads,
            decoder_attention_heads=self.num_attention_heads,
            encoder_ffn_dim=self.intermediate_size,
            decoder_ffn_dim=self.intermediate_size,
            num_conv_layers=self.num_conv_layers,
            conv_kernel_sizes=self.conv_kernel_sizes,
            conv_channels=self.conv_channels,
            input_feat_per_channel=self.input_feat_per_channel,
            input_channels=self.input_channels,
            dropout=self.hidden_dropout_prob,
            attention_dropout=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            max_source_positions=self.max_source_positions,
            max_target_positions=self.max_target_positions,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.bos_token_id,
            pad_token_id=self.pad_token_id,
        )
        inputs_dict = prepare_speech_to_text_inputs_dict(
            config,
            input_features=input_features,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
        )
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Same as :meth:`prepare_config_and_inputs`; entry point used by the common test mixins."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict

    def get_subsampled_output_lengths(self, input_lengths):
        """
        Computes the output length of the convolutional layers
        """

        # one halving step (rounded up) per conv layer
        for _ in range(self.num_conv_layers):
            input_lengths = (input_lengths - 1) // 2 + 1

        return input_lengths

    def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
        """
        Check that decoding three extra tokens with cached ``past_key_values`` gives the same
        hidden states (within ``atol=1e-2``) as running the decoder on the full sequence.
        """
        model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval()
        input_ids = inputs_dict["decoder_input_ids"]
        attention_mask = inputs_dict["decoder_attention_mask"]

        # first forward pass
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create hypothetical multiple next tokens and extend to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size).clamp(2)
        next_attn_mask = ids_tensor((self.batch_size, 3), 2)

        # append to next input_ids and next attention_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([attention_mask, next_attn_mask], dim=-1)

        output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2))

    def check_encoder_decoder_model_standalone(self, config, inputs_dict):
        """
        Check that the encoder and the decoder, each saved and reloaded on its own, reproduce the
        hidden states of the full model (max absolute difference below ``1e-3``).
        """
        model = Speech2TextModel(config=config).to(torch_device).eval()
        outputs = model(**inputs_dict)

        encoder_last_hidden_state = outputs.encoder_last_hidden_state
        last_hidden_state = outputs.last_hidden_state

        with tempfile.TemporaryDirectory() as tmpdirname:
            encoder = model.get_encoder()
            encoder.save_pretrained(tmpdirname)
            encoder = Speech2TextEncoder.from_pretrained(tmpdirname).to(torch_device)

        encoder_last_hidden_state_2 = encoder(
            inputs_dict["input_features"], attention_mask=inputs_dict["attention_mask"]
        )[0]

        self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

        with tempfile.TemporaryDirectory() as tmpdirname:
            decoder = model.get_decoder()
            decoder.save_pretrained(tmpdirname)
            decoder = Speech2TextDecoder.from_pretrained(tmpdirname).to(torch_device)

        last_hidden_state_2 = decoder(
            input_ids=inputs_dict["decoder_input_ids"],
            attention_mask=inputs_dict["decoder_attention_mask"],
            encoder_hidden_states=encoder_last_hidden_state,
            encoder_attention_mask=inputs_dict["attention_mask"],
        )[0]

        self.parent.assertTrue((last_hidden_state_2 - last_hidden_state).abs().max().item() < 1e-3)


@require_torch
class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    # Model classes exercised by the common / generation mixins (empty when torch is unavailable).
    all_model_classes = (Speech2TextModel, Speech2TextForConditionalGeneration) if is_torch_available() else ()
    all_generative_model_classes = (Speech2TextForConditionalGeneration,) if is_torch_available() else ()
    is_encoder_decoder = True
    test_pruning = False
    test_missing_keys = False
    test_torchscript = True

    # The model's main input is called `input_features` rather than `input_ids`.
    input_name = "input_features"

    def setUp(self):
        self.model_tester = Speech2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=Speech2TextConfig)
        self.maxDiff = 3000

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_save_load_strict(self):
        # A save/load round trip must not report any missing keys.
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

    def test_encoder_decoder_model_standalone(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
        self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

    def test_inputs_embeds(self):
        pass

    # training is not supported yet
    def test_training(self):
        pass

    def test_training_gradient_checkpointing(self):
        pass

    def test_generate_fp16(self):
        # Smoke test: generation must run (in half precision when on CUDA) without raising.
        config, input_dict = self.model_tester.prepare_config_and_inputs()
        input_features = input_dict["input_features"]
        attention_mask = input_dict["attention_mask"]
        model = Speech2TextForConditionalGeneration(config).eval().to(torch_device)
        if torch_device == "cuda":
            input_features = input_features.half()
            model.half()
        model.generate(input_features, attention_mask=attention_mask)
        model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = [
                "input_features",
                "attention_mask",
                "decoder_input_ids",
                "decoder_attention_mask",
            ]
            # Every head-mask argument must be present. A chained `"a" and "b" and "c" in arg_names`
            # would only test the last name, because non-empty string literals are always truthy.
            head_mask_args = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
            expected_arg_names.extend(
                [*head_mask_args, "encoder_outputs"]
                if all(name in arg_names for name in head_mask_args)
                else ["encoder_outputs"]
            )
            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            if hasattr(self.model_tester, "encoder_seq_length"):
                seq_length = self.model_tester.encoder_seq_length
            else:
                seq_length = self.model_tester.seq_length

            # encoder hidden states are compared against the subsampled (post-conv) length
            subsampled_seq_length = model._get_subsampled_output_lengths(seq_length)

            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [subsampled_seq_length, self.model_tester.hidden_size],
            )

            if config.is_encoder_decoder:
                hidden_states = outputs.decoder_hidden_states

                self.assertIsInstance(hidden_states, (list, tuple))
                self.assertEqual(len(hidden_states), expected_num_layers)
                seq_len = getattr(self.model_tester, "seq_length", None)
                decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)

                self.assertListEqual(
                    list(hidden_states[0].shape[-2:]),
                    [decoder_seq_length, self.model_tester.hidden_size],
                )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            inputs_dict["output_hidden_states"] = True
            check_hidden_states_output(inputs_dict, config, model_class)

            # check that output_hidden_states also work using config
            del inputs_dict["output_hidden_states"]
            config.output_hidden_states = True

            check_hidden_states_output(inputs_dict, config, model_class)

    def test_attention_outputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            subsampled_encoder_seq_length = model._get_subsampled_output_lengths(encoder_seq_length)
            subsampled_encoder_key_length = model._get_subsampled_output_lengths(encoder_key_length)

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )
            out_len = len(outputs)

            correct_outlen = 5

            # loss is at first position
            if "labels" in inputs_dict:
                correct_outlen += 1  # loss is added to beginning
            if "past_key_values" in outputs:
                correct_outlen += 1  # past_key_values have been returned

            self.assertEqual(out_len, correct_outlen)

            # decoder attentions
            decoder_attentions = outputs.decoder_attentions
            self.assertIsInstance(decoder_attentions, (list, tuple))
            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(decoder_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
            )

            # cross attentions
            cross_attentions = outputs.cross_attentions
            self.assertIsInstance(cross_attentions, (list, tuple))
            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(cross_attentions[0].shape[-3:]),
                [
                    self.model_tester.num_attention_heads,
                    decoder_seq_length,
                    subsampled_encoder_key_length,
                ],
            )

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 2
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, subsampled_encoder_seq_length, subsampled_encoder_key_length],
            )

    def test_resize_tokens_embeddings(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config)
            model.to(torch_device)

            if self.model_tester.is_training is False:
                model.eval()

            model_vocab_size = config.vocab_size
            # Retrieve the embeddings and clone them
            model_embed = model.resize_token_embeddings(model_vocab_size)
            cloned_embeddings = model_embed.weight.clone()

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model_embed = model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)

            # make sure that decoder_input_ids are resized
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that adding and removing tokens has not modified the first part of the embedding matrix.
            models_equal = True
            for p1, p2 in zip(cloned_embeddings, model_embed.weight):
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)

    def test_resize_embeddings_untied(self):
        (
            original_config,
            inputs_dict,
        ) = self.model_tester.prepare_config_and_inputs_for_common()
        if not self.test_resize_embeddings:
            return

        original_config.tie_word_embeddings = False

        # if model cannot untied embeddings -> leave test
        if original_config.tie_word_embeddings:
            return

        for model_class in self.all_model_classes:
            config = copy.deepcopy(original_config)
            model = model_class(config).to(torch_device)

            # if no output embeddings -> leave test
            if model.get_output_embeddings() is None:
                continue

            # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
            model_vocab_size = config.vocab_size
            model.resize_token_embeddings(model_vocab_size + 10)
            self.assertEqual(model.config.vocab_size, model_vocab_size + 10)
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

            # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
            model.resize_token_embeddings(model_vocab_size - 15)
            self.assertEqual(model.config.vocab_size, model_vocab_size - 15)
            # Check that it actually resizes the embeddings matrix
            output_embeds = model.get_output_embeddings()
            self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15)
            # Check bias if present
            if output_embeds.bias is not None:
                self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15)
            # Keep the decoder ids inside the shrunken vocabulary before the forward pass
            if "decoder_input_ids" in inputs_dict:
                inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
            # Check that the model can still do a forward pass successfully (every parameter should be resized)
            model(**self._prepare_for_class(inputs_dict, model_class))

    def test_generate_without_input_ids(self):
        pass

    @staticmethod
    def _get_encoder_outputs(
        model, input_ids, attention_mask, output_attentions=None, output_hidden_states=None, num_interleave=1
    ):
        """
        Run the model's encoder and build the decoder start inputs.

        Returns ``(encoder_outputs, input_ids, attention_mask)`` where ``last_hidden_state`` has been
        repeated ``num_interleave`` times along dim 0, ``input_ids`` is a one-column long tensor filled
        with the decoder start token id, and ``attention_mask`` is ``None``.
        """
        encoder = model.get_encoder()
        encoder_outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.repeat_interleave(
            num_interleave, dim=0
        )
        # drop the last (feature) axis so that the ids tensor below is 2-D
        input_ids = input_ids[:, :, 0]
        input_ids = torch.zeros_like(input_ids[:, :1], dtype=torch.long) + model._get_decoder_start_token_id()
        attention_mask = None
        return encoder_outputs, input_ids, attention_mask

    def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
        """Validate scores, attentions and hidden states of a ``generate`` output, using subsampled encoder lengths."""
        batch_size, seq_length = input_ids.shape[:2]
        subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
        num_sequences_in_output = batch_size * num_return_sequences
        gen_len = (
            output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
        )

        # scores
        self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)

        # Attentions
        # encoder
        self._check_encoder_attention_for_generate(
            output.encoder_attentions, batch_size, config, subsampled_seq_length
        )
        # decoder
        self._check_attentions_for_generate(
            num_sequences_in_output,
            output.decoder_attentions,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

        # Hidden States
        # encoder
        self._check_encoder_hidden_states_for_generate(
            output.encoder_hidden_states, batch_size, config, subsampled_seq_length
        )

        # decoder
        self._check_hidden_states_for_generate(
            num_sequences_in_output,
            output.decoder_hidden_states,
            min_length=1,
            max_length=output.sequences.shape[-1],
            config=config,
            use_cache=use_cache,
        )

    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace each model with ``torch.jit.trace``, save and reload it, and compare the state dicts."""
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()
            inputs = self._prepare_for_class(inputs_dict, model_class)

            try:
                model.config.use_cache = False  # FSTM still requires this hack -> FSTM should probably be refactored similar to BART afterward
                input_features = inputs["input_features"]
                attention_mask = inputs["attention_mask"]
                decoder_input_ids = inputs["decoder_input_ids"]
                decoder_attention_mask = inputs["decoder_attention_mask"]
                traced_model = torch.jit.trace(
                    model, (input_features, attention_mask, decoder_input_ids, decoder_attention_mask)
                )
            except RuntimeError:
                self.fail("Couldn't trace module.")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception:
                    self.fail("Couldn't save module.")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception:
                    self.fail("Couldn't load module.")

            model.to(torch_device)
            model.eval()

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            models_equal = True
            for layer_name, p1 in model_state_dict.items():
                p2 = loaded_model_state_dict[layer_name]
                if p1.data.ne(p2.data).sum() > 0:
                    models_equal = False

            self.assertTrue(models_equal)


@require_torch
@require_torchaudio
@require_sentencepiece
@require_tokenizers
@slow
class Speech2TextModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_processor(self):
        return Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")

    def _load_datasamples(self, num_samples):
        # Returns the decoded audio of the first `num_samples` items of the dummy LibriSpeech validation split.
        from datasets import load_dataset

        import soundfile as sf

        # map files to raw
        def map_to_array(batch):
            speech, _ = sf.read(batch["file"])
            batch["speech"] = speech
            return batch

        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
        ds = ds.select(range(num_samples)).map(map_to_array)

        return ds["speech"][:num_samples]

    def test_generation_librispeech(self):
        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(1)

        input_features = processor(input_speech, return_tensors="pt").input_features.to(torch_device)

        generated_ids = model.generate(input_features)
        generated_transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)

        EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
        self.assertListEqual(generated_transcript, EXPECTED_TRANSCRIPTIONS)

    def test_generation_librispeech_batched(self):
        model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
        model.to(torch_device)
        processor = self.default_processor

        input_speech = self._load_datasamples(4)

        # padding=True so the four utterances can be batched; the attention mask marks the padding
        inputs = processor(input_speech, return_tensors="pt", padding=True)

        input_features = inputs.input_features.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)

        generated_ids = model.generate(input_features, attention_mask=attention_mask)
        generated_transcripts = processor.batch_decode(generated_ids, skip_special_tokens=True)

        EXPECTED_TRANSCRIPTIONS = [
            "a man said to the universe sir i exist",
            "sweat covered brion's body trickling into the titleing cloth that was the only garment he wore",
            "the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about",
            "his instant of panic was followed by a small sharp blow high on his chest",
        ]

        self.assertListEqual(generated_transcripts, EXPECTED_TRANSCRIPTIONS)