mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[Examples] Generalise run audio classification for log-mel models (#21756)
* [Examples] Generalise run audio classification for log-mel models * batch feature extractor * make style
This commit is contained in:
parent
f7ca656f07
commit
13489248fa
@ -289,24 +289,27 @@ def main():
|
||||
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
|
||||
)
|
||||
|
||||
model_input_name = feature_extractor.model_input_names[0]
|
||||
|
||||
def train_transforms(batch):
|
||||
"""Apply train_transforms across a batch."""
|
||||
output_batch = {"input_values": []}
|
||||
subsampled_wavs = []
|
||||
for audio in batch[data_args.audio_column_name]:
|
||||
wav = random_subsample(
|
||||
audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
|
||||
)
|
||||
output_batch["input_values"].append(wav)
|
||||
subsampled_wavs.append(wav)
|
||||
inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
|
||||
output_batch = {model_input_name: inputs.get(model_input_name)}
|
||||
output_batch["labels"] = list(batch[data_args.label_column_name])
|
||||
|
||||
return output_batch
|
||||
|
||||
def val_transforms(batch):
|
||||
"""Apply val_transforms across a batch."""
|
||||
output_batch = {"input_values": []}
|
||||
for audio in batch[data_args.audio_column_name]:
|
||||
wav = audio["array"]
|
||||
output_batch["input_values"].append(wav)
|
||||
wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
|
||||
inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
|
||||
output_batch = {model_input_name: inputs.get(model_input_name)}
|
||||
output_batch["labels"] = list(batch[data_args.label_column_name])
|
||||
|
||||
return output_batch
|
||||
|
Loading…
Reference in New Issue
Block a user