mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
is_batched fix for remaining 2-D numpy arrays (#23309)
* Fix is_batched code to allow 2-D numpy arrays for audio * Tests * Fix typo * Incorporate comments from PR #23223
This commit is contained in:
parent
6b7d6f848b
commit
3d57404464
@ -135,7 +135,8 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values.
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
sampling_rate (`int`, *optional*):
|
||||
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
|
||||
`sampling_rate` at the forward call to prevent silent errors.
|
||||
@ -160,9 +161,11 @@ class ASTFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -272,7 +272,8 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values.
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
truncation (`str`, *optional*):
|
||||
Truncation pattern for long audio inputs. Two patterns are available:
|
||||
- `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
|
||||
@ -312,9 +313,11 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -180,7 +180,8 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`torch.Tensor`, `np.ndarray`, `List[float]`, `List[torch.Tensor]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a tensor, a numpy array, a list
|
||||
of float values, a list of tensors, a list of numpy arrays or a list of list of float values.
|
||||
of float values, a list of tensors, a list of numpy arrays or a list of list of float values. Must be
|
||||
mono channel audio, not stereo, i.e. single float per timestep.
|
||||
padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
@ -231,9 +232,11 @@ class MCTCTFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -141,7 +141,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values.
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
|
||||
Select a strategy to pad the returned sequences (according to the model's padding side and padding
|
||||
index) among:
|
||||
@ -200,9 +201,11 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -201,7 +201,8 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
|
||||
The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values. This outputs waveform features.
|
||||
values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must
|
||||
be mono channel audio, not stereo, i.e. single float per timestep.
|
||||
audio_target (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, *optional*):
|
||||
The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a
|
||||
list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel
|
||||
@ -307,9 +308,11 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
is_batched = bool(
|
||||
isinstance(speech, (list, tuple))
|
||||
and (isinstance(speech[0], np.ndarray) or isinstance(speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1
|
||||
if is_batched_numpy and len(speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -129,7 +129,8 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values.
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors instead of list of python integers. Acceptable values are:
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
@ -176,9 +177,11 @@ class TvltFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
if is_batched:
|
||||
raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
|
||||
|
@ -152,7 +152,8 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
Args:
|
||||
raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
|
||||
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
|
||||
values, a list of numpy arrays or a list of list of float values.
|
||||
values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
|
||||
stereo, i.e. single float per timestep.
|
||||
truncation (`bool`, *optional*, default to `True`):
|
||||
Activates truncation to cut input sequences longer than *max_length* to *max_length*.
|
||||
pad_to_multiple_of (`int`, *optional*, defaults to None):
|
||||
@ -203,9 +204,11 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
|
||||
"Failing to do so can result in silent errors that might be hard to debug."
|
||||
)
|
||||
|
||||
is_batched = bool(
|
||||
isinstance(raw_speech, (list, tuple))
|
||||
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
|
||||
is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
|
||||
if is_batched_numpy and len(raw_speech.shape) > 2:
|
||||
raise ValueError(f"Only mono-channel audio is supported for input to {self}")
|
||||
is_batched = is_batched_numpy or (
|
||||
isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
|
||||
)
|
||||
|
||||
if is_batched:
|
||||
|
@ -125,6 +125,14 @@ class ASTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Test
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feat_extract(speech_inputs, return_tensors="np").input_values
|
||||
encoded_sequences_2 = feat_extract(np_speech_inputs, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
@require_torch
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
@ -139,6 +139,14 @@ class ClapFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
|
||||
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_double_precision_pad(self):
|
||||
import torch
|
||||
|
||||
|
@ -134,6 +134,14 @@ class MCTCTFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Te
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
|
||||
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(8000, 14000, 2000)]
|
||||
|
@ -136,6 +136,14 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
|
||||
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_cepstral_mean_and_variance_normalization(self):
|
||||
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
|
||||
|
@ -275,6 +275,14 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_values
|
||||
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_values
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
def test_batch_feature_target(self):
|
||||
speech_inputs = self.feat_extract_tester.prepare_inputs_for_target()
|
||||
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
|
||||
|
@ -189,6 +189,15 @@ class TvltFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.Tes
|
||||
self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
|
||||
self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_audios = feature_extractor(np_speech_inputs, return_tensors="np", sampling_rate=44100).audio_values
|
||||
self.assertTrue(encoded_audios.ndim == 4)
|
||||
self.assertTrue(encoded_audios.shape[-1] == feature_extractor.feature_size)
|
||||
self.assertTrue(encoded_audios.shape[-2] <= feature_extractor.spectrogram_length)
|
||||
self.assertTrue(encoded_audios.shape[-3] == feature_extractor.num_channels)
|
||||
|
||||
def _load_datasamples(self, num_samples):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
# automatic decoding with librispeech
|
||||
|
@ -173,6 +173,14 @@ class WhisperFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test 2-D numpy arrays are batched.
|
||||
speech_inputs = [floats_list((1, x))[0] for x in (800, 800, 800)]
|
||||
np_speech_inputs = np.asarray(speech_inputs)
|
||||
encoded_sequences_1 = feature_extractor(speech_inputs, return_tensors="np").input_features
|
||||
encoded_sequences_2 = feature_extractor(np_speech_inputs, return_tensors="np").input_features
|
||||
for enc_seq_1, enc_seq_2 in zip(encoded_sequences_1, encoded_sequences_2):
|
||||
self.assertTrue(np.allclose(enc_seq_1, enc_seq_2, atol=1e-3))
|
||||
|
||||
# Test truncation required
|
||||
speech_inputs = [floats_list((1, x))[0] for x in range(200, (feature_extractor.n_samples + 500), 200)]
|
||||
np_speech_inputs = [np.asarray(speech_input) for speech_input in speech_inputs]
|
||||
|
Loading…
Reference in New Issue
Block a user