mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enabling automatic loading of tokenizer with pipeline
for (#13376)
`audio-classification`.
This commit is contained in:
parent
e92140c567
commit
c9184a2e03
@ -449,6 +449,13 @@ def pipeline(
|
||||
load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
|
||||
load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
|
||||
|
||||
if task in {"audio-classification"}:
|
||||
# Audio classification will never require a tokenizer.
|
||||
# the model on the other hand might have a tokenizer, but
|
||||
# the files could be missing from the hub, instead of failing
|
||||
# on such repos, we just force to not load it.
|
||||
load_tokenizer = False
|
||||
|
||||
if load_tokenizer:
|
||||
# Try to infer tokenizer from model or config name (if provided as str)
|
||||
if tokenizer is None:
|
||||
|
@ -16,7 +16,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, PreTrainedTokenizer
|
||||
from transformers import MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
|
||||
from transformers.pipelines import AudioClassificationPipeline, pipeline
|
||||
from transformers.testing_utils import (
|
||||
is_pipeline_test,
|
||||
@ -77,9 +77,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
def test_small_model_pt(self):
|
||||
model = "anton-l/wav2vec2-random-tiny-classifier"
|
||||
|
||||
# hack: dummy tokenizer is required to prevent pipeline from failing
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
audio_classifier = pipeline("audio-classification", model=model)
|
||||
|
||||
audio = np.ones((8000,))
|
||||
output = audio_classifier(audio, top_k=4)
|
||||
@ -101,9 +99,7 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
||||
|
||||
model = "superb/wav2vec2-base-superb-ks"
|
||||
|
||||
# hack: dummy tokenizer is required to prevent pipeline from failing
|
||||
tokenizer = PreTrainedTokenizer()
|
||||
audio_classifier = pipeline("audio-classification", model=model, tokenizer=tokenizer)
|
||||
audio_classifier = pipeline("audio-classification", model=model)
|
||||
dataset = datasets.load_dataset("anton-l/superb_dummy", "ks", split="test")
|
||||
|
||||
audio = np.array(dataset[3]["speech"], dtype=np.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user