Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[Sequence Feature Extraction] Add truncation (#12804)
* fix_torch_device_generate_test
* remove @
* add truncate
* finish
* correct test
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* clean tests
* correct normalization for truncation
* remove casting
* up
* save intermed
* finish
* finish
* correct

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 98364ea74f
commit f6e254474c
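Not part of the commit itself — a minimal usage sketch of the truncation support this change adds, assuming an illustrative Wav2Vec2 feature-extractor configuration; as in the tests below, `truncation=True` is paired with `max_length` and `padding="max_length"`:

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

# illustrative configuration; any checkpoint's feature extractor would behave the same way
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
)

# two fake waveforms of different lengths
raw_speech = [np.random.randn(800).astype(np.float32), np.random.randn(1200).astype(np.float32)]

# the longer input is cut down to max_length, the shorter one is padded up to it
inputs = feature_extractor(
    raw_speech,
    padding="max_length",
    max_length=1000,
    truncation=True,
    return_tensors="np",
    return_attention_mask=True,
)
print(inputs.input_values.shape)  # (2, 1000) under these assumptions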
@@ -69,6 +69,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
        ],
        padding: Union[bool, str, PaddingStrategy] = True,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
@@ -107,6 +108,8 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
                different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (:obj:`bool`):
                Activates truncation to cut input sequences longer than :obj:`max_length` to :obj:`max_length`.
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

@@ -178,10 +181,18 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
            processed_features[key] = to_py_obj(value)

        # Convert padding_strategy in PaddingStrategy
        padding_strategy, max_length, _ = self._get_padding_strategies(padding=padding, max_length=max_length)
        padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)

        required_input = processed_features[self.model_input_names[0]]
        if required_input and not isinstance(required_input[0], (list, tuple)):
            # truncation
            processed_features = self._truncate(
                processed_features,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                truncation=truncation,
            )
            # padding
            processed_features = self._pad(
                processed_features,
                max_length=max_length,
@@ -196,13 +207,32 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
            len(v) == batch_size for v in processed_features.values()
        ), "Some items in the output dictionary have a different batch size than others."

        truncated_inputs = []
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in processed_features.items())
            # truncation
            inputs = self._truncate(
                inputs,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                truncation=truncation,
            )
            truncated_inputs.append(inputs)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = max(len(inputs) for inputs in required_input)
            padding_strategy = PaddingStrategy.MAX_LENGTH

        batch_outputs = {}
        for i in range(batch_size):
            inputs = dict((k, v[i]) for k, v in processed_features.items())
            # truncation
            inputs = self._truncate(
                truncated_inputs[i],
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                truncation=truncation,
            )
            # padding
            outputs = self._pad(
                inputs,
                max_length=max_length,
@@ -278,6 +308,46 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):

        return processed_features

    def _truncate(
        self,
        processed_features: Union[Dict[str, List[float]], BatchFeature],
        max_length: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
        truncation: Optional[bool] = None,
    ):
        """
        Pad inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            processed_features: Dictionary of input values (`List[float]`) / input vectors (`List[List[float]]`) or batch of inputs values (`List[List[int]]`) / input vectors (`List[List[List[int]]]`)
            max_length: maximum length of the returned list and optionally padding length (see below)
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                >= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
            truncation: (optional) Activates truncation to cut input sequences longer than `max_length` to `max_length`.
        """
        if not truncation:
            return processed_features
        elif truncation and max_length is None:
            raise ValueError(
                "When setting ``truncation=True``, make sure that ``max_length`` is defined and ``padding='max_length'``"
            )

        required_input = processed_features[self.model_input_names[0]]

        # find `max_length` that fits `pad_to_multiple_of`
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_truncated = len(required_input) > max_length

        if needs_to_be_truncated:
            processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
            if "attention_mask" in processed_features:
                processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]

        return processed_features

    def _get_padding_strategies(self, padding=False, max_length=None, pad_to_multiple_of=None, **kwargs):
        """
        Find the correct padding strategy
@@ -308,4 +378,4 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
                "Please select a value to use as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
            )

        return padding_strategy, max_length, kwargs
        return padding_strategy
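Not part of the diff — a standalone sketch of the rule `_truncate` implements above, for a single un-batched example (`truncate_inputs` and `input_name` are illustrative names): do nothing unless truncation is requested, require `max_length`, round `max_length` up to the next multiple of `pad_to_multiple_of`, then slice both the model input and the attention mask:

from typing import Dict, List, Optional


def truncate_inputs(
    inputs: Dict[str, List[float]],
    input_name: str,
    max_length: Optional[int] = None,
    pad_to_multiple_of: Optional[int] = None,
    truncation: bool = False,
) -> Dict[str, List[float]]:
    # no-op unless truncation was explicitly requested
    if not truncation:
        return inputs
    if max_length is None:
        raise ValueError("truncation=True requires max_length to be set")

    # round max_length up so the truncated sequence still fits pad_to_multiple_of
    if pad_to_multiple_of is not None and max_length % pad_to_multiple_of != 0:
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

    # cut the model input and its attention mask to the (possibly rounded) max_length
    if len(inputs[input_name]) > max_length:
        inputs[input_name] = inputs[input_name][:max_length]
        if "attention_mask" in inputs:
            inputs["attention_mask"] = inputs["attention_mask"][:max_length]
    return inputs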
@@ -93,28 +93,37 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):

    @staticmethod
    def utterance_cmvn(
        x: np.ndarray, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
        x: np.ndarray, input_length: int, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
    ) -> np.ndarray:
        mean = x.mean(axis=0)
        square_sums = (x ** 2).sum(axis=0)
        # make sure we normalie float32 arrays

        mean = x[:input_length].mean(axis=0)
        square_sums = (x[:input_length] ** 2).sum(axis=0)

        if normalize_means:
            x = np.subtract(x, mean)
        if normalize_vars:
            var = square_sums / x.shape[0] - mean ** 2
            var = square_sums / x[:input_length].shape[0] - mean ** 2
            std = np.sqrt(np.maximum(var, 1e-10))
            x = np.divide(x, std)

        # make sure array is in float32
        x = x.astype(np.float32)

        return x

    def normalize(self, input_values: List[np.ndarray]) -> List[np.ndarray]:
        return [self.utterance_cmvn(x, self.normalize_means, self.normalize_vars) for x in input_values]
    def normalize(self, input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
        return [
            self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars)
            for x, n in zip(input_values, input_lengths)
        ]

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        sampling_rate: Optional[int] = None,
@@ -140,6 +149,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
                different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (:obj:`bool`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

@@ -191,6 +202,8 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
            raw_speech = [np.asarray(speech) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
@@ -199,10 +212,6 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
        # extract fbank features
        features = [self._extract_fbank_features(waveform) for waveform in raw_speech]

        # Utterance-level cepstral mean and variance normalization
        if self.do_ceptral_normalize:
            features = self.normalize(features)

        # convert into correct format for padding
        encoded_inputs = BatchFeature({"input_features": features})

@@ -210,10 +219,29 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
            **kwargs,
        )

        if "attention_mask" in padded_inputs:
            input_lengths = padded_inputs["attention_mask"].sum(-1)
        else:
            padded_input_values = padded_inputs["input_features"]
            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]

        # Utterance-level cepstral mean and variance normalization
        if self.do_ceptral_normalize:
            input_features = padded_inputs["input_features"]

            # make sure list is in array format
            if isinstance(input_features[0], list):
                input_features = [np.asarray(feature, dtype=np.float32) for feature in input_features]

            padded_inputs["input_features"] = self.normalize(input_features, input_lengths=input_lengths)

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
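Not part of the diff — a NumPy sketch of the length-aware utterance CMVN above: statistics are computed only over the first `input_length` (real) frames, but every frame of the padded array is shifted and scaled, so padded frames no longer skew the mean and variance (the "correct normalization for truncation" item in the commit message):

import numpy as np


def utterance_cmvn_sketch(x: np.ndarray, input_length: int) -> np.ndarray:
    # x: (num_frames, feature_size) fbank features, possibly right-padded beyond input_length
    mean = x[:input_length].mean(axis=0)
    var = (x[:input_length] ** 2).sum(axis=0) / x[:input_length].shape[0] - mean ** 2
    std = np.sqrt(np.maximum(var, 1e-10))

    # normalize all frames (padding included) with statistics of the real frames only
    x = (x - mean) / std
    return x.astype(np.float32)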
@@ -758,7 +758,6 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            attention_mask = self._get_subsampled_encoder_attn_mask(attention_mask)

@@ -79,17 +79,24 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
        self.do_normalize = do_normalize

    @staticmethod
    def zero_mean_unit_var_norm(input_values: List[np.ndarray]) -> List[np.ndarray]:
    def zero_mean_unit_var_norm(input_values: List[np.ndarray], input_lengths: List[int]) -> List[np.ndarray]:
        """
        Every array in the list is normalized to have zero mean and unit variance
        """
        return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]
        if isinstance(input_values[0], np.ndarray):
            input_values = [x.astype(np.float32) for x in input_values]

        normed_input_values = [
            (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5) for x, i in zip(input_values, input_lengths)
        ]
        return normed_input_values

    def __call__(
        self,
        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = False,
        max_length: Optional[int] = None,
        truncation: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
@@ -115,6 +122,8 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
                different lengths).
            max_length (:obj:`int`, `optional`):
                Maximum length of the returned list and optionally padding length (see above).
            truncation (:obj:`bool`):
                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            pad_to_multiple_of (:obj:`int`, `optional`):
                If set will pad the sequence to a multiple of the provided value.

@@ -168,18 +177,16 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):

        # make sure input is in list format
        if is_batched and not isinstance(raw_speech[0], np.ndarray):
            raw_speech = [np.asarray(speech) for speech in raw_speech]
            raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
        elif not is_batched and not isinstance(raw_speech, np.ndarray):
            raw_speech = np.asarray(raw_speech)
            raw_speech = np.asarray(raw_speech, dtype=np.float32)
        elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
            raw_speech = raw_speech.astype(np.float32)

        # always return batch
        if not is_batched:
            raw_speech = [raw_speech]

        # zero-mean and unit-variance normalization
        if self.do_normalize:
            raw_speech = self.zero_mean_unit_var_norm(raw_speech)

        # convert into correct format for padding
        encoded_inputs = BatchFeature({"input_values": raw_speech})

@@ -187,9 +194,24 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
            encoded_inputs,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_tensors=return_tensors,
        )

        if "attention_mask" in padded_inputs:
            input_lengths = padded_inputs["attention_mask"].sum(-1)
        else:
            padded_input_values = padded_inputs["input_values"]
            input_lengths = [padded_input_values.shape[-1] for _ in range(padded_input_values.shape[0])]

        # zero-mean and unit-variance normalization
        if self.do_normalize:
            padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
                padded_inputs["input_values"], input_lengths=input_lengths
            )

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs
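Not part of the diff — the same idea applied to raw waveforms, following the `zero_mean_unit_var_norm` and `__call__` changes above: per-sample statistics come from the unpadded prefix `x[:i]`, where the lengths `i` are read off the attention mask when one is returned (illustrative names):

import numpy as np


def zero_mean_unit_var_sketch(input_values, attention_mask=None):
    # input_values: list of 1-D float32 waveforms, possibly already padded to a common length
    if attention_mask is not None:
        input_lengths = attention_mask.sum(-1)
    else:
        # without a mask, treat every sample as fully valid
        input_lengths = [len(x) for x in input_values]

    return [
        (x - np.mean(x[:i])) / np.sqrt(np.var(x[:i]) + 1e-5)
        for x, i in zip(input_values, input_lengths)
    ]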
@@ -91,6 +91,7 @@ class Speech2TextFeatureExtractionTester(unittest.TestCase):
        if equal_length:
            speech_inputs = [floats_list((self.max_seq_length, self.feature_size)) for _ in range(self.batch_size)]
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                floats_list((x, self.feature_size))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@@ -147,3 +148,26 @@ class Speech2TextFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unitt
        _check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
        _check_zero_mean_unit_variance(input_features[1, : fbank_feat_lengths[1]])
        _check_zero_mean_unit_variance(input_features[2, : fbank_feat_lengths[2]])

    def test_cepstral_mean_and_variance_normalization_trunc(self):
        feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        inputs = feature_extractor(
            speech_inputs,
            padding="max_length",
            max_length=4,
            truncation=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        input_features = inputs.input_features
        attention_mask = inputs.attention_mask
        fbank_feat_lengths = np.sum(attention_mask == 1, axis=1)

        def _check_zero_mean_unit_variance(input_vector):
            self.assertTrue(np.all(np.mean(input_vector, axis=0) < 1e-3))
            self.assertTrue(np.all(np.abs(np.var(input_vector, axis=0) - 1) < 1e-3))

        _check_zero_mean_unit_variance(input_features[0, : fbank_feat_lengths[0]])
        _check_zero_mean_unit_variance(input_features[1])
        _check_zero_mean_unit_variance(input_features[2])
@@ -83,6 +83,7 @@ class Wav2Vec2FeatureExtractionTester(unittest.TestCase):
        if equal_length:
            speech_inputs = floats_list((self.batch_size, self.max_seq_length))
        else:
            # make sure that inputs increase in size
            speech_inputs = [
                _flatten(floats_list((x, self.feature_size)))
                for x in range(self.min_seq_length, self.max_seq_length, self.seq_length_diff)
@@ -122,7 +123,7 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
    def test_zero_mean_unit_variance_normalization(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(speech_inputs, padding="longest")
        processed = feat_extract(speech_inputs, padding="longest", return_tensors="np")
        input_values = processed.input_values

        def _check_zero_mean_unit_variance(input_vector):
@@ -133,6 +134,22 @@ class Wav2Vec2FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest
        _check_zero_mean_unit_variance(input_values[1, :1000])
        _check_zero_mean_unit_variance(input_values[2])

    def test_zero_mean_unit_variance_normalization_trunc(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
        speech_inputs = [floats_list((1, x))[0] for x in range(800, 1400, 200)]
        processed = feat_extract(
            speech_inputs, truncation=True, max_length=1000, padding="max_length", return_tensors="np"
        )
        input_values = processed.input_values

        def _check_zero_mean_unit_variance(input_vector):
            self.assertTrue(np.abs(np.mean(input_vector)) < 1e-3)
            self.assertTrue(np.abs(np.var(input_vector) - 1) < 1e-3)

        _check_zero_mean_unit_variance(input_values[0, :800])
        _check_zero_mean_unit_variance(input_values[1])
        _check_zero_mean_unit_variance(input_values[2])

    @slow
    @require_torch
    def test_pretrained_checkpoints_are_set_correctly(self):
@@ -715,7 +715,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):

        input_speech = self._load_datasamples(2)

        inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
        inputs = processor(input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)

@@ -737,7 +737,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):

        input_speech = self._load_datasamples(4)

        inputs = processor(input_speech, return_tensors="pt", padding=True, truncation=True)
        inputs = processor(input_speech, return_tensors="pt", padding=True)

        input_values = inputs.input_values.to(torch_device)
        attention_mask = inputs.attention_mask.to(torch_device)
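Not part of the diff — why `truncation=True` was dropped from the integration-test processor calls above: with the new `_truncate` check, passing `truncation=True` without `max_length` raises a `ValueError`, so plain `padding=True` is the right call when no truncation is wanted. A hedged sketch with an illustrative extractor configuration:

import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True
)
raw_speech = [np.random.randn(n).astype(np.float32) for n in (800, 1200)]

try:
    # truncation without max_length is now rejected
    feature_extractor(raw_speech, padding=True, truncation=True)
except ValueError as err:
    print(err)

# supported combination: padding="max_length" together with max_length and truncation=True
inputs = feature_extractor(raw_speech, padding="max_length", max_length=1000, truncation=True)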
@@ -126,12 +126,17 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
        feature_size = self.feat_extract_tester.feature_size

        # test padding for List[int] + numpy
        input_1 = feat_extract.pad(processed_features, padding=False)[input_name]
        input_2 = feat_extract.pad(processed_features, padding="longest")[input_name]
        input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))[
            input_name
        ]
        input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")[input_name]
        input_1 = feat_extract.pad(processed_features, padding=False)
        input_1 = input_1[input_name]

        input_2 = feat_extract.pad(processed_features, padding="longest")
        input_2 = input_2[input_name]

        input_3 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[-1]))
        input_3 = input_3[input_name]

        input_4 = feat_extract.pad(processed_features, padding="longest", return_tensors="np")
        input_4 = input_4[input_name]

        # max_length parameter has to be provided when setting `padding="max_length"`
        with self.assertRaises(ValueError):
@@ -139,7 +144,8 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):

        input_5 = feat_extract.pad(
            processed_features, padding="max_length", max_length=pad_max_length, return_tensors="np"
        )[input_name]
        )
        input_5 = input_5[input_name]

        self.assertFalse(_inputs_have_equal_length(input_1))
        self.assertTrue(_inputs_have_equal_length(input_2))
@@ -154,18 +160,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
        self.assertTrue(input_4.shape[2] == input_5.shape[2] == feature_size)

        # test padding for `pad_to_multiple_of` for List[int] + numpy
        input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)[input_name]
        input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)[input_name]
        input_6 = feat_extract.pad(processed_features, pad_to_multiple_of=10)
        input_6 = input_6[input_name]

        input_7 = feat_extract.pad(processed_features, padding="longest", pad_to_multiple_of=10)
        input_7 = input_7[input_name]

        input_8 = feat_extract.pad(
            processed_features, padding="max_length", pad_to_multiple_of=10, max_length=pad_max_length
        )[input_name]
        )
        input_8 = input_8[input_name]

        input_9 = feat_extract.pad(
            processed_features,
            padding="max_length",
            pad_to_multiple_of=10,
            max_length=pad_max_length,
            return_tensors="np",
        )[input_name]
        )
        input_9 = input_9[input_name]

        self.assertTrue(all(len(x) % 10 == 0 for x in input_6))
        self.assertTrue(_inputs_are_equal(input_6, input_7))
@@ -205,12 +218,149 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
                < 1e-3
            )

    def _check_truncation(self, numpify=False):
        def _inputs_have_equal_length(input):
            length = len(input[0])
            for input_slice in input[1:]:
                if len(input_slice) != length:
                    return False
            return True

        def _inputs_are_equal(input_1, input_2):
            if len(input_1) != len(input_2):
                return False

            for input_slice_1, input_slice_2 in zip(input_1, input_2):
                if not np.allclose(np.asarray(input_slice_1), np.asarray(input_slice_2), atol=1e-3):
                    return False
            return True

        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common(numpify=numpify)
        input_name = feat_extract.model_input_names[0]

        processed_features = BatchFeature({input_name: speech_inputs})

        # truncate to smallest
        input_1 = feat_extract.pad(
            processed_features, padding="max_length", max_length=len(speech_inputs[0]), truncation=True
        )
        input_1 = input_1[input_name]

        input_2 = feat_extract.pad(processed_features, padding="max_length", max_length=len(speech_inputs[0]))
        input_2 = input_2[input_name]

        self.assertTrue(_inputs_have_equal_length(input_1))
        self.assertFalse(_inputs_have_equal_length(input_2))

        # truncate to smallest with np
        input_3 = feat_extract.pad(
            processed_features,
            padding="max_length",
            max_length=len(speech_inputs[0]),
            return_tensors="np",
            truncation=True,
        )
        input_3 = input_3[input_name]

        input_4 = feat_extract.pad(
            processed_features, padding="max_length", max_length=len(speech_inputs[0]), return_tensors="np"
        )
        input_4 = input_4[input_name]

        self.assertTrue(_inputs_have_equal_length(input_3))
        self.assertTrue(input_3.shape[1] == len(speech_inputs[0]))

        # since truncation forces padding to be smaller than longest input
        # function can't return `np.ndarray`, but has to return list
        self.assertFalse(_inputs_have_equal_length(input_4))

        # truncate to middle
        input_5 = feat_extract.pad(
            processed_features,
            padding="max_length",
            max_length=len(speech_inputs[1]),
            truncation=True,
            return_tensors="np",
        )
        input_5 = input_5[input_name]

        input_6 = feat_extract.pad(
            processed_features, padding="max_length", max_length=len(speech_inputs[1]), truncation=True
        )
        input_6 = input_6[input_name]

        input_7 = feat_extract.pad(
            processed_features, padding="max_length", max_length=len(speech_inputs[1]), return_tensors="np"
        )
        input_7 = input_7[input_name]

        self.assertTrue(input_5.shape[1] == len(speech_inputs[1]))
        self.assertTrue(_inputs_have_equal_length(input_5))
        self.assertTrue(_inputs_have_equal_length(input_6))
        self.assertTrue(_inputs_are_equal(input_5, input_6))

        # since truncation forces padding to be smaller than longest input
        # function can't return `np.ndarray`, but has to return list
        self.assertFalse(_inputs_have_equal_length(input_7))
        self.assertTrue(len(input_7[-1]) == len(speech_inputs[-1]))

        # padding has to be max_length when setting `truncation=True`
        with self.assertRaises(ValueError):
            feat_extract.pad(processed_features, truncation=True)[input_name]

        # padding has to be max_length when setting `truncation=True`
        with self.assertRaises(ValueError):
            feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]

        # padding has to be max_length when setting `truncation=True`
        with self.assertRaises(ValueError):
            feat_extract.pad(processed_features, padding="longest", truncation=True)[input_name]

        # max_length parameter has to be provided when setting `truncation=True` and padding="max_length"
        with self.assertRaises(ValueError):
            feat_extract.pad(processed_features, padding="max_length", truncation=True)[input_name]

        # test truncation for `pad_to_multiple_of` for List[int] + numpy
        pad_to_multiple_of = 12
        input_8 = feat_extract.pad(
            processed_features,
            padding="max_length",
            max_length=len(speech_inputs[0]),
            pad_to_multiple_of=pad_to_multiple_of,
            truncation=True,
        )
        input_8 = input_8[input_name]

        input_9 = feat_extract.pad(
            processed_features,
            padding="max_length",
            max_length=len(speech_inputs[0]),
            pad_to_multiple_of=pad_to_multiple_of,
        )
        input_9 = input_9[input_name]

        # retrieve expected_length as multiple of pad_to_multiple_of
        expected_length = len(speech_inputs[0])
        if expected_length % pad_to_multiple_of != 0:
            expected_length = ((len(speech_inputs[0]) // pad_to_multiple_of) + 1) * pad_to_multiple_of

        self.assertTrue(len(input_8[0]) == expected_length)
        self.assertTrue(_inputs_have_equal_length(input_8))
        self.assertFalse(_inputs_have_equal_length(input_9))

    def test_padding_from_list(self):
        self._check_padding(numpify=False)

    def test_padding_from_array(self):
        self._check_padding(numpify=True)

    def test_truncation_from_list(self):
        self._check_truncation(numpify=False)

    def test_truncation_from_array(self):
        self._check_truncation(numpify=True)

    @require_torch
    def test_padding_accepts_tensors_pt(self):
        feat_extract = self.feature_extraction_class(**self.feat_extract_dict)
@@ -251,3 +401,25 @@ class SequenceFeatureExtractionTestMixin(FeatureExtractionSavingTestMixin):
        self.assertIn("attention_mask", processed)
        self.assertListEqual(list(processed.attention_mask.shape), list(processed[input_name].shape[:2]))
        self.assertListEqual(processed.attention_mask.sum(-1).tolist(), input_lenghts)

    def test_attention_mask_with_truncation(self):
        feat_dict = self.feat_extract_dict
        feat_dict["return_attention_mask"] = True
        feat_extract = self.feature_extraction_class(**feat_dict)
        speech_inputs = self.feat_extract_tester.prepare_inputs_for_common()
        input_lenghts = [len(x) for x in speech_inputs]
        input_name = feat_extract.model_input_names[0]

        processed = BatchFeature({input_name: speech_inputs})
        max_length = min(input_lenghts)

        processed_pad = feat_extract.pad(
            processed, padding="max_length", max_length=max_length, truncation=True, return_tensors="np"
        )
        self.assertIn("attention_mask", processed_pad)
        self.assertListEqual(
            list(processed_pad.attention_mask.shape), list((processed_pad[input_name].shape[0], max_length))
        )
        self.assertListEqual(
            processed_pad.attention_mask[:, :max_length].sum(-1).tolist(), [max_length for x in speech_inputs]
        )
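Not part of the diff — the arithmetic behind `expected_length` in `_check_truncation` above: when `max_length` is not already a multiple of `pad_to_multiple_of`, `_truncate` rounds it up to the next multiple, so the truncated-and-padded length is that multiple. For example, with `pad_to_multiple_of = 12` as in the test and an illustrative `max_length` of 800:

max_length = 800
pad_to_multiple_of = 12

# same rounding as in SequenceFeatureExtractor._truncate
if max_length % pad_to_multiple_of != 0:
    max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

print(max_length)  # 804, the next multiple of 12 above 800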