[Wav2Vec2] Fix dtype 64 bug (#13517)

* fix

* 2nd fix
This commit is contained in:
Patrick von Platen 2021-09-10 18:19:10 +02:00 committed by GitHub
parent 72ec2f3eb5
commit a57d784df5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -210,7 +210,7 @@ class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.float64:
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
raw_speech = raw_speech.astype(np.float32)
# always return batch

View File

@ -207,10 +207,10 @@ class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
elif (
not isinstance(input_values, np.ndarray)
and isinstance(input_values[0], np.ndarray)
and input_values[0].dtype is np.float64
and input_values[0].dtype is np.dtype(np.float64)
):
padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.float64:
elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64):
padded_inputs["input_values"] = input_values.astype(np.float32)
# convert attention_mask to correct format