mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[CLAP] Support batched inputs for CLAP. Fixes pipeline issues (#21931)
* fix pipeline * fix feature_extraction clap * you can now batch the `is_longer` attribute * add tests * fixup * add expected scores * comment on is_longert
This commit is contained in:
parent
c5fe06c59d
commit
718e9d777f
@ -347,6 +347,9 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
|
||||
if isinstance(input_mel[0], List):
|
||||
input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel]
|
||||
|
||||
# is_longer is a list of bool
|
||||
is_longer = [[longer] for longer in is_longer]
|
||||
|
||||
input_features = {"input_features": input_mel, "is_longer": is_longer}
|
||||
input_features = BatchFeature(input_features)
|
||||
|
||||
|
@ -44,6 +44,7 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
>>> audio = next(iter(dataset["train"]["audio"]))["array"]
|
||||
>>> classifier = pipeline(task="zero-shot-audio-classification", model="laion/clap-htsat-unfused")
|
||||
>>> classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
|
||||
[{'score': 0.9995999932289124, 'label': 'Sound of a dog'}, {'score': 0.00040007088682614267, 'label': 'Sound of vaccum cleaner'}]
|
||||
```
|
||||
|
||||
|
||||
@ -118,6 +119,7 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
sequences = [hypothesis_template.format(x) for x in candidate_labels]
|
||||
text_inputs = self.tokenizer(sequences, return_tensors=self.framework, padding=True)
|
||||
inputs["text_inputs"] = [text_inputs]
|
||||
return inputs
|
||||
|
||||
def _forward(self, model_inputs):
|
||||
candidate_labels = model_inputs.pop("candidate_labels")
|
||||
@ -131,8 +133,8 @@ class ZeroShotAudioClassificationPipeline(Pipeline):
|
||||
outputs = self.model(**text_inputs, **model_inputs)
|
||||
|
||||
model_outputs = {
|
||||
"candidate_label": candidate_labels,
|
||||
"logits_per_audio": outputs.logits_per_audio,
|
||||
"candidate_labels": candidate_labels,
|
||||
"logits": outputs.logits_per_audio,
|
||||
}
|
||||
return model_outputs
|
||||
|
||||
|
@ -665,3 +665,55 @@ class ClapModelIntegrationTest(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
|
||||
)
|
||||
|
||||
def test_batched_fused(self):
|
||||
EXPECTED_MEANS_FUSED = {
|
||||
"repeatpad": 0.0010,
|
||||
"repeat": 0.0020,
|
||||
"pad": 0.0006,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
audio_samples = [sample["array"] for sample in librispeech_dummy[0:4]["audio"]]
|
||||
|
||||
model_id = "laion/clap-htsat-fused"
|
||||
|
||||
model = ClapModel.from_pretrained(model_id).to(torch_device)
|
||||
processor = ClapProcessor.from_pretrained(model_id)
|
||||
|
||||
for padding in self.paddings:
|
||||
inputs = processor(audios=audio_samples, return_tensors="pt", padding=padding, truncation="fusion").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
audio_embed = model.get_audio_features(**inputs)
|
||||
expected_mean = EXPECTED_MEANS_FUSED[padding]
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
|
||||
)
|
||||
|
||||
def test_batched_unfused(self):
|
||||
EXPECTED_MEANS_FUSED = {
|
||||
"repeatpad": 0.0016,
|
||||
"repeat": 0.0019,
|
||||
"pad": 0.0019,
|
||||
}
|
||||
|
||||
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
audio_samples = [sample["array"] for sample in librispeech_dummy[0:4]["audio"]]
|
||||
|
||||
model_id = "laion/clap-htsat-unfused"
|
||||
|
||||
model = ClapModel.from_pretrained(model_id).to(torch_device)
|
||||
processor = ClapProcessor.from_pretrained(model_id)
|
||||
|
||||
for padding in self.paddings:
|
||||
inputs = processor(audios=audio_samples, return_tensors="pt", padding=padding).to(torch_device)
|
||||
|
||||
audio_embed = model.get_audio_features(**inputs)
|
||||
expected_mean = EXPECTED_MEANS_FUSED[padding]
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(audio_embed.cpu().mean(), torch.tensor([expected_mean]), atol=1e-3, rtol=1e-3)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user