[Sequence Feature Extraction] Add truncation (#12804)

* fix_torch_device_generate_test

* remove @

* add truncate

* finish

* correct test

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* clean tests

* correct normalization for truncation

* remove casting

* up

* save intermed

* finish

* finish

* correct

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen 2021-07-23 17:53:30 +02:00 committed by GitHub
parent 98364ea74f
commit f6e254474c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 370 additions and 38 deletions

View File

@ -69,6 +69,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
],
padding: Union[bool, str, PaddingStrategy] = True,
max_length: Optional[int] = None,
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
@ -107,6 +108,8 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
truncation (:obj:`bool`):
Activates truncation to cut input sequences longer than :obj:`max_length` to :obj:`max_length`.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
@ -178,10 +181,18 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
processed_features[key] = to_py_obj(value)
# Convert padding_strategy in PaddingStrategy
padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
required_input = processed_features[self.model_input_names[0]]
if required_input and not isinstance(required_input[0], (list, tuple)):
# truncation
processed_features = self._truncate(
processed_features,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
truncation=truncation,
)
# padding
processed_features = self._pad(
processed_features,
max_length=max_length,
@ -196,13 +207,32 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
len(v) == batch_size for v in processed_features.values()
), "Some items in the output dictionary have a different batch size than others."
truncated_inputs = []
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
# truncation
inputs = self._truncate(
inputs,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
truncation=truncation,
)
truncated_inputs.append(inputs)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = max(len(inputs) for inputs in required_input)
padding_strategy = PaddingStrategy.MAX_LENGTH
batch_outputs = {}
for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items())
# truncation
inputs = self._truncate(
truncated_inputs[i],
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
truncation=truncation,
)
# padding
outputs = self._pad(
inputs,
max_length=max_length,
@ -278,6 +308,46 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
return processed_features
def _truncate(
self,
processed_features: Union[Dict[str, List[float]], BatchFeature],
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
truncation: Optional[bool] = None,
):
"""
Pad inputs (on left/right and up to predefined length or max length in the batch)
Args:
processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
max_length: maximum length of the returned list and optionally padding length (see below)
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
truncation: (optional) Activates truncation to cut input sequences longer than `max_length` to `max_length`.
"""
if not truncation:
return processed_features
elif truncation and max_length is None:
raise ValueError(
"When setting ``truncation=True``, make sure that ``max_length`` is defined and ``padding='max_length'``"
)
required_input = processed_features[self.model_input_names[0]]
# find `max_length` that fits `pad_to_multiple_of`
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_truncated = len(required_input) > max_length
if needs_to_be_truncated:
processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
if "attention_mask" in processed_features:
processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
return processed_features
def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
"""
Find the correct padding strategy
@ -308,4 +378,4 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
"Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
)
return padding_strategy, max_length, kwargs
return padding_strategy

View File

@ -93,28 +93,37 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
@staticmethod
def utterance_cmvn(
x: np.ndarray, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
) -> np.ndarray:
mean = x.mean(axis=0)
square_sums = (x ** 2).sum(axis=0)
# make sure we normalie float32 arrays
mean = x[:input_length].mean(axis=0)
square_sums = (x[:input_length] ** 2).sum(axis=0)
if normalize_means:
x = np.subtract(x, mean)
if normalize_vars:
var = square_sums / x.shape[0] - mean ** 2
var = square_sums / x[:input_length].shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
# make sure array is in float32
x = x.astype(np.float32)
return x
def normalize(self, input_values: List[np.ndarray]) -> List[np.ndarray]:
return [self.utterance_cmvn(x, self.normalize_means, self.normalize_vars) for x in input_values]
def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
return [
self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars)
for x, n in zip(input_values, input_lengths)
]
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Union[bool, str, PaddingStrategy] = False,
max_length: Optional[int] = None,
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
@ -140,6 +149,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
truncation (:obj:`bool`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
@ -191,6 +202,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
raw_speech = raw_speech.astype(np.float32)
# always return batch
if not is_batched:
@ -199,10 +212,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
# extract fbank features
features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
# Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize:
features = self.normalize(features)
# convert into correct format for padding
encoded_inputs = BatchFeature({"input_features": features})
@ -210,10 +219,29 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
encoded_inputs,
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_tensors=return_tensors,
**kwargs,
)
if "attention_mask" in padded_inputs:
input_lengths = padded_inputs["attention_mask"].sum(-1)
else:
padded_input_values = padded_inputs["input_features"]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
# Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize:
input_features = padded_inputs["input_features"]
# make sure list is in array format
if isinstance(input_features[0], list):
input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features]
padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths)
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs

View File

@ -758,7 +758,6 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if attention_mask is not None:
attention_mask = self._get_subsampled_encoder_attn_mask(attention_mask)

View File

@ -79,17 +79,24 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
self.do_normalize = do_normalize
@staticmethod
def zero_mean_unit_var_norm(input_values: List[np.ndarray]) -> List[np.ndarray]:
def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
"""
Every array in the list is normalized to have zero mean and unit variance
"""
return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]
if isinstance(input_values[0], np.ndarray):
input_values = [x.astype(np.float32) for x in input_values]
normed_input_values = [
(x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
]
return normed_input_values
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Union[bool, str, PaddingStrategy] = False,
max_length: Optional[int] = None,
truncation: bool = False,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
@ -115,6 +122,8 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
truncation (:obj:`bool`):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
@ -168,18 +177,16 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech) for speech in raw_speech]
raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
raw_speech = raw_speech.astype(np.float32)
# always return batch
if not is_batched:
raw_speech = [raw_speech]
# zero-mean and unit-variance normalization
if self.do_normalize:
raw_speech = self.zero_mean_unit_var_norm(raw_speech)
# convert into correct format for padding
encoded_inputs = BatchFeature({"input_values": raw_speech})
@ -187,9 +194,24 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
encoded_inputs,
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_tensors=return_tensors,
)
if "attention_mask" in padded_inputs:
input_lengths = padded_inputs["attention_mask"].sum(-1)
else:
padded_input_values = padded_inputs["input_values"]
input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]
# zero-mean and unit-variance normalization
if self.do_normalize:
padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
padded_inputs["input_values"], input_lengths=input_lengths
)
if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
return padded_inputs

View File

@ -91,6 +91,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
if equal_length:
speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
else:
# make sure that inputs increase in size
speech_inputs = [
floats_list((x, self.feature_size))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@ -147,3 +148,26 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
_check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])
def test_cepstral_mean_and_variance_normalization_trunc(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
inputs = feature_extractor(
speech_inputs,
padding="max_length",
max_length=4,
truncation=True,
return_tensors="np",
return_attention_mask=True,
)
input_features = inputs.input_features
attention_mask = inputs.attention_mask
fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))
_check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
_check_zero_mean_unit_variance(input_features[1])
_check_zero_mean_unit_variance(input_features[2])

View File

@ -83,6 +83,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
if equal_length:
speech_inputs = floats_list((self.batch_size, self.max_seq_length))
else:
# make sure that inputs increase in size
speech_inputs = [
_flatten(floats_list((x, self.feature_size)))
for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@ -122,7 +123,7 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
def test_zero_mean_unit_variance_normalization(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(speech_inputs, padding="longest")
processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
@ -133,6 +134,22 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
_check_zero_mean_unit_variance(input_values[1, :1000])
_check_zero_mean_unit_variance(input_values[2])
def test_zero_mean_unit_variance_normalization_trunc(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
processed = feat_extract(
speech_inputs, truncation=True, max_length=1000, padding="max_length", return_tensors="np"
)
input_values = processed.input_values
def _check_zero_mean_unit_variance(input_vector):
self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)
_check_zero_mean_unit_variance(input_values[0, :800])
_check_zero_mean_unit_variance(input_values[1])
_check_zero_mean_unit_variance(input_values[2])
@slow
@require_torch
def test_pretrained_checkpoints_are_set_correctly(self):

View File

@ -715,7 +715,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(2)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
@ -737,7 +737,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
input_speech = self._load_datasamples(4)
inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
inputs = processor(input_speech, return_tensors="pt", padding=True)
input_values = inputs.input_values.to(torch_device)
attention_mask = inputs.attention_mask.to(torch_device)

View File

@ -126,12 +126,17 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
feature_size = self.feat_extract_tester.feature_size
# test padding for List[int] + numpy
input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
input_name
]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
input_1 = feat_extract.pad(processed_features, padding=False)
input_1 = input_1[input_name]
input_2 = feat_extract.pad(processed_features, padding="longest")
input_2 = input_2[input_name]
input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))
input_3 = input_3[input_name]
input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")
input_4 = input_4[input_name]
# max_length parameter has to be provided when setting `padding="max_length"`
with self.assertRaises(ValueError):
@ -139,7 +144,8 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
input_5 = feat_extract.pad(
processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
)[input_name]
)
input_5 = input_5[input_name]
self.assertFalse(_inputs_have_equal_length(input_1))
self.assertTrue(_inputs_have_equal_length(input_2))
@ -154,18 +160,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)
# test padding for `pad_to_multiple_of` for List[int] + numpy
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)
input_6 = input_6[input_name]
input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)
input_7 = input_7[input_name]
input_8 = feat_extract.pad(
processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
)[input_name]
)
input_8 = input_8[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
pad_to_multiple_of=10,
max_length=pad_max_length,
return_tensors="np",
)[input_name]
)
input_9 = input_9[input_name]
self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
self.assertTrue(_inputs_are_equal(input_6, input_7))
@ -205,12 +218,149 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
< 1e-3
)
def _check_truncation(self, numpify=False):
def _inputs_have_equal_length(input):
length = len(input[0])
for input_slice in input[1:]:
if len(input_slice) != length:
return False
return True
def _inputs_are_equal(input_1, input_2):
if len(input_1) != len(input_2):
return False
for input_slice_1, input_slice_2 in zip(input_1, input_2):
if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
return False
return True
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
input_name = feat_extract.model_input_names[0]
processed_features = BatchFeature({input_name: speech_inputs})
# truncate to smallest
input_1 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[0]), truncation=True
)
input_1 = input_1[input_name]
input_2 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[0]))
input_2 = input_2[input_name]
self.assertTrue(_inputs_have_equal_length(input_1))
self.assertFalse(_inputs_have_equal_length(input_2))
# truncate to smallest with np
input_3 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
return_tensors="np",
truncation=True,
)
input_3 = input_3[input_name]
input_4 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[0]), return_tensors="np"
)
input_4 = input_4[input_name]
self.assertTrue(_inputs_have_equal_length(input_3))
self.assertTrue(input_3.shape[1] == len(speech_inputs[0]))
# since truncation forces padding to be smaller than longest input
# function can't return `np.ndarray`, but has to return list
self.assertFalse(_inputs_have_equal_length(input_4))
# truncate to middle
input_5 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[1]),
truncation=True,
return_tensors="np",
)
input_5 = input_5[input_name]
input_6 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[1]), truncation=True
)
input_6 = input_6[input_name]
input_7 = feat_extract.pad(
processed_features, padding="max_length", max_length=len(speech_inputs[1]), return_tensors="np"
)
input_7 = input_7[input_name]
self.assertTrue(input_5.shape[1] == len(speech_inputs[1]))
self.assertTrue(_inputs_have_equal_length(input_5))
self.assertTrue(_inputs_have_equal_length(input_6))
self.assertTrue(_inputs_are_equal(input_5, input_6))
# since truncation forces padding to be smaller than longest input
# function can't return `np.ndarray`, but has to return list
self.assertFalse(_inputs_have_equal_length(input_7))
self.assertTrue(len(input_7[-1]) == len(speech_inputs[-1]))
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, truncation=True)[input_name]
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]
# padding has to be max_length when setting `truncation=True`
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]
# max_length parameter has to be provided when setting `truncation=True` and padding="max_length"
with self.assertRaises(ValueError):
feat_extract.pad(processed_features, padding="max_length", truncation=True)[input_name]
# test truncation for `pad_to_multiple_of` for List[int] + numpy
pad_to_multiple_of = 12
input_8 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
pad_to_multiple_of=pad_to_multiple_of,
truncation=True,
)
input_8 = input_8[input_name]
input_9 = feat_extract.pad(
processed_features,
padding="max_length",
max_length=len(speech_inputs[0]),
pad_to_multiple_of=pad_to_multiple_of,
)
input_9 = input_9[input_name]
# retrieve expected_length as multiple of pad_to_multiple_of
expected_length = len(speech_inputs[0])
if expected_length % pad_to_multiple_of != 0:
expected_length = ((len(speech_inputs[0]) // pad_to_multiple_of) + 1) * pad_to_multiple_of
self.assertTrue(len(input_8[0]) == expected_length)
self.assertTrue(_inputs_have_equal_length(input_8))
self.assertFalse(_inputs_have_equal_length(input_9))
def test_padding_from_list(self):
self._check_padding(numpify=False)
def test_padding_from_array(self):
self._check_padding(numpify=True)
def test_truncation_from_list(self):
self._check_truncation(numpify=False)
def test_truncation_from_array(self):
self._check_truncation(numpify=True)
@require_torch
def test_padding_accepts_tensors_pt(self):
feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
@ -251,3 +401,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
self.assertIn("attention_mask", processed)
self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)
def test_attention_mask_with_truncation(self):
feat_dict = self.feat_extract_dict
feat_dict["return_attention_mask"] = True
feat_extract = self.feature_extraction_class(**feat_dict)
speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
input_lenghts = [len(x) for x in speech_inputs]
input_name = feat_extract.model_input_names[0]
processed = BatchFeature({input_name: speech_inputs})
max_length = min(input_lenghts)
processed_pad = feat_extract.pad(
processed, padding="max_length", max_length=max_length, truncation=True, return_tensors="np"
)
self.assertIn("attention_mask", processed_pad)
self.assertListEqual(
list(processed_pad.attention_mask.shape), list((processed_pad[input_name].shape[0], max_length))
)
self.assertListEqual(
processed_pad.attention_mask[:, :max_length].sum(-1).tolist(), [max_length for x in speech_inputs]
)