Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Pipeline testing - using tiny models on Hub (#20426)

* rework pipeline tests
* run pipeline tests
* fix
* fix
* fix
* revert the changes in get_test_pipeline() parameter list
* fix expected error message
* skip a test
* clean up

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

This commit is contained in: parent a582cfce3c, commit c749bd405e
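The change this diff makes everywhere below is twofold: pipeline test classes now receive a single `processor` argument in `get_test_pipeline()` (instead of separate `feature_extractor` and `image_processor` arguments), and the test metaclass loads ready-made tiny checkpoints from the `hf-internal-testing` organization on the Hub instead of building tiny models locally from `_CHECKPOINT_FOR_DOC` / model-tester configs. A minimal sketch of the new hook, modeled on the audio-classification test below (the repo id follows the `tiny-random-<Architecture>` naming used by the new metaclass and is an assumption here, not a guaranteed checkpoint):

```python
import numpy as np

from transformers import AudioClassificationPipeline, AutoFeatureExtractor, AutoModelForAudioClassification

# Assumed repo id, following the f"hf-internal-testing/tiny-random-{model_arch_name}" pattern used in this diff.
repo_id = "hf-internal-testing/tiny-random-Wav2Vec2ForSequenceClassification"
model = AutoModelForAudioClassification.from_pretrained(repo_id)
processor = AutoFeatureExtractor.from_pretrained(repo_id)


def get_test_pipeline(model, tokenizer, processor):
    # New contract: `processor` is whichever tokenizer-companion the tiny repo provides
    # (a feature extractor here; an image processor for vision pipelines).
    audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=processor)
    audio = np.zeros((34000,))  # raw waveform, as in the test below
    return audio_classifier, [audio]


pipe, examples = get_test_pipeline(model, tokenizer=None, processor=processor)
print(pipe(examples[0]))
```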
@@ -434,7 +434,7 @@ def create_circleci_config(folder=None):
    example_file = os.path.join(folder, "examples_test_list.txt")
    if os.path.exists(example_file) and os.path.getsize(example_file) > 0:
        jobs.extend(EXAMPLES_TESTS)

    repo_util_file = os.path.join(folder, "test_repo_utils.txt")
    if os.path.exists(repo_util_file) and os.path.getsize(repo_util_file) > 0:
        jobs.extend(REPO_UTIL_TESTS)
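For reference, the gating pattern used above (schedule extra CircleCI jobs only when the test fetcher wrote a non-empty test-list file) boils down to the small helper below; it is a generic restatement for illustration, not code from the repo:

```python
import os


def maybe_extend_jobs(jobs, folder, filename, extra_jobs):
    # Only schedule `extra_jobs` when the corresponding test-list file exists and is non-empty.
    path = os.path.join(folder, filename)
    if os.path.exists(path) and os.path.getsize(path) > 0:
        jobs.extend(extra_jobs)
    return jobs
```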
@@ -27,8 +27,8 @@ from .test_pipelines_common import ANY, PipelineTestCaseMeta
class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=feature_extractor)
+    def get_test_pipeline(self, model, tokenizer, processor):
+        audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=processor)

        # test with a raw waveform
        audio = np.zeros((34000,))


@@ -60,7 +60,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
        + (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
    }

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        if tokenizer is None:
            # Side effect of no Fast Tokenizer class for these model, so skipping
            # But the slow tokenizer test should still run as they're quite small
@@ -69,7 +69,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
            # return None, None

        speech_recognizer = AutomaticSpeechRecognitionPipeline(
-            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+            model=model, tokenizer=tokenizer, feature_extractor=processor
        )

        # test with a raw waveform
@@ -133,7 +133,9 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
            )
        else:
            # Non CTC models cannot use return_timestamps
-            with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
+            with self.assertRaisesRegex(
+                ValueError, "^We cannot return_timestamps yet on non-ctc models apart from Whisper !$"
+            ):
                outputs = speech_recognizer(audio, return_timestamps="char")

    @require_torch
@@ -17,26 +17,20 @@ import importlib
import logging
import os
import random
import string
import sys
import tempfile
import unittest
from abc import abstractmethod
from functools import lru_cache
from pathlib import Path
from unittest import skipIf

import datasets
import numpy as np

import requests
from huggingface_hub import HfFolder, Repository, create_repo, delete_repo, set_access_token
from requests.exceptions import HTTPError
from transformers import (
    FEATURE_EXTRACTOR_MAPPING,
    IMAGE_PROCESSOR_MAPPING,
    TOKENIZER_MAPPING,
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertForSequenceClassification,
@@ -71,123 +65,16 @@ from test_module.custom_pipeline import PairClassificationPipeline  # noqa E402
logger = logging.getLogger(__name__)


ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS = [
    "CamembertConfig",
    "IBertConfig",
    "LongformerConfig",
    "MarkupLMConfig",
    "RobertaConfig",
    "RobertaPreLayerNormConfig",
    "XLMRobertaConfig",
]
PATH_TO_TRANSFORMERS = os.path.join(Path(__file__).parent.parent.parent, "src/transformers")


def get_checkpoint_from_architecture(architecture):
    try:
        module = importlib.import_module(architecture.__module__)
    except ImportError:
        logger.error(f"Ignoring architecture {architecture}")
        return

    if hasattr(module, "_CHECKPOINT_FOR_DOC"):
        return module._CHECKPOINT_FOR_DOC
    else:
        logger.warning(f"Can't retrieve checkpoint from {architecture.__name__}")


def get_tiny_config_from_class(configuration_class):
    if "OpenAIGPT" in configuration_class.__name__:
        # This is the only file that is inconsistent with the naming scheme.
        # Will rename this file if we decide this is the way to go
        return

    model_type = configuration_class.model_type
    camel_case_model_name = configuration_class.__name__.split("Config")[0]

    try:
        model_slug = model_type.replace("-", "_")
        module = importlib.import_module(f".test_modeling_{model_slug}", package=f"tests.models.{model_slug}")
        model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
    except (ImportError, AttributeError):
        logger.error(f"No model tester class for {configuration_class.__name__}")
        return

    if model_tester_class is None:
        logger.warning(f"No model tester class for {configuration_class.__name__}")
        return

    model_tester = model_tester_class(parent=None)

    if hasattr(model_tester, "get_pipeline_config"):
        config = model_tester.get_pipeline_config()
    elif hasattr(model_tester, "get_config"):
        config = model_tester.get_config()
    else:
        config = None
        logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")

    return config


@lru_cache(maxsize=100)
def get_tiny_tokenizer_from_checkpoint(checkpoint):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if tokenizer.vocab_size < 300:
        # Wav2Vec2ForCTC for instance
        # ByT5Tokenizer
        # all are already small enough and have no Fast version that can
        # be retrained
        return tokenizer
    logger.info("Training new from iterator ...")
    vocabulary = string.ascii_letters + string.digits + " "
    tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
    logger.info("Trained.")
    return tokenizer


def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config, feature_extractor_class):
    try:
        feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    except Exception:
        try:
            if feature_extractor_class is not None:
                feature_extractor = feature_extractor_class()
            else:
                feature_extractor = None
        except Exception:
            feature_extractor = None

    # Audio Spectogram Transformer specific.
    if feature_extractor.__class__.__name__ == "ASTFeatureExtractor":
        feature_extractor = feature_extractor.__class__(
            max_length=tiny_config.max_length, num_mel_bins=tiny_config.num_mel_bins
        )

    # Speech2TextModel specific.
    if hasattr(tiny_config, "input_feat_per_channel") and feature_extractor:
        feature_extractor = feature_extractor.__class__(
            feature_size=tiny_config.input_feat_per_channel, num_mel_bins=tiny_config.input_feat_per_channel
        )
    # TODO remove this, once those have been moved to `image_processor`.
    if hasattr(tiny_config, "image_size") and feature_extractor:
        feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
    return feature_extractor


def get_tiny_image_processor_from_checkpoint(checkpoint, tiny_config, image_processor_class):
    try:
        image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    except Exception:
        try:
            if image_processor_class is not None:
                image_processor = image_processor_class()
            else:
                image_processor = None
        except Exception:
            image_processor = None
    if hasattr(tiny_config, "image_size") and image_processor:
        image_processor = image_processor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
    return image_processor
# Dynamically import the Transformers module to grab the attribute classes of the processor form their names.
spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers_module = spec.loader.load_module()


class ANY:
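The helpers deleted above used to rebuild tiny tokenizers, feature extractors and image processors per checkpoint; what survives is the dynamic import of `transformers` from source so that processor classes can later be resolved by name with `getattr(transformers_module, name)`. The hunk itself calls `spec.loader.load_module()`; a self-contained sketch of the same idea with the non-deprecated `importlib` calls (the path is illustrative):

```python
import importlib.util
import os

# Illustrative path; the test file derives it from its own location via Path(__file__).
PATH_TO_TRANSFORMERS = os.path.join("src", "transformers")

spec = importlib.util.spec_from_file_location(
    "transformers",
    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
    submodule_search_locations=[PATH_TO_TRANSFORMERS],
)
transformers_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(transformers_module)

# Processor classes can now be looked up from their string names, e.g. names read from the tiny model summary.
tokenizer_class = getattr(transformers_module, "BertTokenizerFast")
```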
@@ -201,76 +88,171 @@ class ANY:
        return f"ANY({', '.join(_type.__name__ for _type in self._types)})"


def is_test_to_skip(test_casse_name, config_class, model_architecture, tokenizer_name, processor_name):
    """Some tests are just not working"""

    to_skip = False

    if config_class.__name__ == "RoCBertConfig" and test_casse_name in [
        "FillMaskPipelineTests",
        "FeatureExtractionPipelineTests",
        "TextClassificationPipelineTests",
        "TokenClassificationPipelineTests",
    ]:
        # Get error: IndexError: index out of range in self.
        # `word_shape_file` and `word_pronunciation_file` should be shrunk during tiny model creation,
        # otherwise `IndexError` could occur in some embedding layers. Skip for now until this model has
        # more usage.
        to_skip = True
    elif config_class.__name__ in ["LayoutLMv3Config", "LiltConfig"]:
        # Get error: ValueError: Words must be of type `List[str]`. Previously, `LayoutLMv3` is not
        # used in pipeline tests as it could not find a checkpoint
        # TODO: check and fix if possible
        to_skip = True
    # config/model class we decide to skip
    elif config_class.__name__ in ["TapasConfig"]:
        # Get error: AssertionError: Table must be of type pd.DataFrame. Also, the tiny model has large
        # vocab size as the fast tokenizer could not be converted. Previous, `Tapas` is not used in
        # pipeline tests due to the same reason.
        # TODO: check and fix if possible
        to_skip = True

    # TODO: check and fix if possible
    if not to_skip and tokenizer_name is not None:

        if (
            test_casse_name == "QAPipelineTests"
            and not tokenizer_name.endswith("Fast")
            and config_class.__name__
            in [
                "FlaubertConfig",
                "GPTJConfig",
                "LongformerConfig",
                "MvpConfig",
                "OPTConfig",
                "ReformerConfig",
                "XLMConfig",
            ]
        ):
            # `QAPipelineTests` fails for a few models when the slower tokenizer are used.
            # (The slower tokenizers were never used for pipeline tests before the pipeline testing rework)
            # TODO: check (and possibly fix) the `QAPipelineTests` with slower tokenizer
            to_skip = True
        elif test_casse_name == "ZeroShotClassificationPipelineTests" and config_class.__name__ in [
            "CTRLConfig",
            "OpenAIGPTConfig",
        ]:
            # Get `tokenizer does not have a padding token` error for both fast/slow tokenizers.
            # `CTRLConfig` and `OpenAIGPTConfig` were never used in pipeline tests, either because of a missing
            # checkpoint or because a tiny config could not be created
            to_skip = True
        elif test_casse_name == "TranslationPipelineTests" and config_class.__name__ in [
            "M2M100Config",
            "PLBartConfig",
        ]:
            # Get `ValueError: Translation requires a `src_lang` and a `tgt_lang` for this model`.
            # `M2M100Config` and `PLBartConfig` were never used in pipeline tests: cannot create a simple tokenizer
            to_skip = True
        elif test_casse_name == "TextGenerationPipelineTests" and config_class.__name__ in [
            "ProphetNetConfig",
            "TransfoXLConfig",
        ]:
            # Get `ValueError: AttributeError: 'NoneType' object has no attribute 'new_ones'` or `AssertionError`.
            # `TransfoXLConfig` and `ProphetNetConfig` were never used in pipeline tests: cannot create a simple
            # tokenizer.
            to_skip = True
        elif test_casse_name == "FillMaskPipelineTests" and config_class.__name__ in [
            "FlaubertConfig",
            "XLMConfig",
        ]:
            # Get `ValueError: AttributeError: 'NoneType' object has no attribute 'new_ones'` or `AssertionError`.
            # `FlaubertConfig` and `TransfoXLConfig` were never used in pipeline tests: cannot create a simple
            # tokenizer
            to_skip = True
        elif test_casse_name == "TextGenerationPipelineTests" and model_architecture.__name__ in [
            "TFRoFormerForCausalLM"
        ]:
            # TODO: add `prepare_inputs_for_generation` for `TFRoFormerForCausalLM`
            to_skip = True
        elif test_casse_name == "QAPipelineTests" and model_architecture.__name__ in ["FNetForQuestionAnswering"]:
            # TODO: The change in `base.py` in the PR #21132 (https://github.com/huggingface/transformers/pull/21132)
            # fails this test case. Skip for now - a fix for this along with the initial changes in PR #20426 is
            # too much. Let `ydshieh` to fix it ASAP once #20426 is merged.
            to_skip = True

    return to_skip


def validate_test_components(test_case, model, tokenizer, processor):

    # TODO: Move this to tiny model creation script
    # head-specific (within a model type) necessary changes to the config
    # 1. for `BlenderbotForCausalLM`
    if model.__class__.__name__ == "BlenderbotForCausalLM":
        model.config.encoder_no_repeat_ngram_size = 0

    # TODO: Change the tiny model creation script: don't create models with problematic tokenizers
    # Avoid `IndexError` in embedding layers
    CONFIG_WITHOUT_VOCAB_SIZE = ["CanineConfig"]
    if tokenizer is not None:
        config_vocab_size = getattr(model.config, "vocab_size", None)
        # For CLIP-like models
        if config_vocab_size is None and hasattr(model.config, "text_config"):
            config_vocab_size = getattr(model.config.text_config, "vocab_size", None)
        if config_vocab_size is None and model.config.__class__.__name__ not in CONFIG_WITHOUT_VOCAB_SIZE:
            raise ValueError(
                "Could not determine `vocab_size` from model configuration while `tokenizer` is not `None`."
            )
        # TODO: Remove tiny models from the Hub which have problematic tokenizers (but still keep this block)
        if config_vocab_size is not None and len(tokenizer) > config_vocab_size:
            test_case.skipTest(
                f"Ignore {model.__class__.__name__}: `tokenizer` ({tokenizer.__class__.__name__}) has"
                f" {len(tokenizer)} tokens which is greater than `config_vocab_size`"
                f" ({config_vocab_size}). Something is wrong."
            )


class PipelineTestCaseMeta(type):
    def __new__(mcs, name, bases, dct):
        def gen_test(
            ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class, image_processor_class
        ):
        def gen_test(repo_name, model_architecture, tokenizer_name, processor_name):
            @skipIf(
                tiny_config is None,
                "TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
                " file",
            )
            @skipIf(
                checkpoint is None,
                "checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
                " modeling file",
                tokenizer_name is None and processor_name is None,
                f"Ignore {model_architecture.__name__}: no processor class is provided (tokenizer, image processor,"
                " feature extractor, etc)",
            )
            def test(self):
                if ModelClass.__name__.endswith("ForCausalLM"):
                    tiny_config.is_encoder_decoder = False
                    if hasattr(tiny_config, "encoder_no_repeat_ngram_size"):
                        # specific for blenderbot which supports both decoder-only
                        # encoder/decoder but the test config only reflects
                        # encoder/decoder arch
                        tiny_config.encoder_no_repeat_ngram_size = 0
                if ModelClass.__name__.endswith("WithLMHead"):
                    tiny_config.is_decoder = True
                repo_id = f"hf-internal-testing/{repo_name}"

                tokenizer = None
                if tokenizer_name is not None:
                    tokenizer_class = getattr(transformers_module, tokenizer_name)
                    tokenizer = tokenizer_class.from_pretrained(repo_id)

                processor = None
                if processor_name is not None:
                    processor_class = getattr(transformers_module, processor_name)
                    # If the required packages (like `Pillow`) are not installed, this will fail.
                    try:
                        processor = processor_class.from_pretrained(repo_id)
                    except Exception:
                        self.skipTest(f"Ignore {model_architecture.__name__}: could not load the model from {repo_id}")

                try:
                    model = ModelClass(tiny_config)
                except ImportError as e:
                    self.skipTest(
                        f"Cannot run with {tiny_config} as the model requires a library that isn't installed: {e}"
                    )
                    model = model_architecture.from_pretrained(repo_id)
                except Exception:
                    self.skipTest(f"Ignore {model_architecture.__name__}: could not load the model from {repo_id}")

                # validate
                validate_test_components(self, model, tokenizer, processor)

                if hasattr(model, "eval"):
                    model = model.eval()
                if tokenizer_class is not None:
                    try:
                        tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
                        # XLNet actually defines it as -1.
                        if model.config.__class__.__name__ in ROBERTA_EMBEDDING_ADJUSMENT_CONFIGS:
                            tokenizer.model_max_length = model.config.max_position_embeddings - 2
                        elif (
                            hasattr(model.config, "max_position_embeddings")
                            and model.config.max_position_embeddings > 0
                        ):
                            tokenizer.model_max_length = model.config.max_position_embeddings
                    # Rust Panic exception are NOT Exception subclass
                    # Some test tokenizer contain broken vocabs or custom PreTokenizer, so we
                    # provide some default tokenizer and hope for the best.
                    except:  # noqa: E722
                        self.skipTest(f"Ignoring {ModelClass}, cannot create a simple tokenizer")
                else:
                    tokenizer = None

                feature_extractor = get_tiny_feature_extractor_from_checkpoint(
                    checkpoint, tiny_config, feature_extractor_class
                )

                image_processor = get_tiny_image_processor_from_checkpoint(
                    checkpoint, tiny_config, image_processor_class
                )

                if tokenizer is None and feature_extractor is None and image_processor:
                    self.skipTest(
                        f"Ignoring {ModelClass}, cannot create a tokenizer or feature_extractor or image_processor"
                        " (PerceiverConfig with no FastTokenizer ?)"
                    )
                pipeline, examples = self.get_test_pipeline(model, tokenizer, feature_extractor, image_processor)
                pipeline, examples = self.get_test_pipeline(model, tokenizer, processor)
                if pipeline is None:
                    # The test can disable itself, but it should be very marginal
                    # Concerns: Wav2Vec2ForCTC without tokenizer test (FastTokenizer don't exist)
                    return
                    self.skipTest(f"Ignore {model_architecture.__name__}: could not create the pipeline")
                self.run_pipeline_test(pipeline, examples)

        def run_batch_test(pipeline, examples):
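The rewritten `gen_test` resolves everything from a single tiny repo on the Hub: the model architecture, the tokenizer class name and the processor class name are all turned into `from_pretrained` calls against `hf-internal-testing/tiny-random-<Architecture>`. A condensed, self-contained sketch of that flow (the class and repo names are illustrative; the real code resolves them via the dynamically imported `transformers_module` and calls `self.skipTest(...)` on failure):

```python
import transformers

model_arch_name = "BertForSequenceClassification"  # canonical name, "TF"/"Flax" prefix already stripped
repo_id = f"hf-internal-testing/tiny-random-{model_arch_name}"  # naming pattern from the metaclass above

tokenizer_name = "BertTokenizerFast"  # in the tests this comes from the tiny model summary JSON
tokenizer = getattr(transformers, tokenizer_name).from_pretrained(repo_id)

model = getattr(transformers, model_arch_name).from_pretrained(repo_id)
model = model.eval()  # PyTorch models are switched to eval mode before the pipeline runs
```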
@@ -294,52 +276,45 @@ class PipelineTestCaseMeta(type):

            return test

        # Download tiny model summary (used to avoid requesting from Hub too many times)
        url = "https://huggingface.co/datasets/hf-internal-testing/tiny-random-model-summary/raw/main/processor_classes.json"
        tiny_model_summary = requests.get(url).json()

        for prefix, key in [("pt", "model_mapping"), ("tf", "tf_model_mapping")]:
            mapping = dct.get(key, {})
            if mapping:
                for configuration, model_architectures in mapping.items():
                for config_class, model_architectures in mapping.items():

                    if not isinstance(model_architectures, tuple):
                        model_architectures = (model_architectures,)

                    for model_architecture in model_architectures:
                        checkpoint = get_checkpoint_from_architecture(model_architecture)
                        tiny_config = get_tiny_config_from_class(configuration)
                        tokenizer_classes = TOKENIZER_MAPPING.get(configuration, [])
                        feature_extractor_class = FEATURE_EXTRACTOR_MAPPING.get(configuration, None)
                        feature_extractor_name = (
                            feature_extractor_class.__name__ if feature_extractor_class else "nofeature_extractor"
                        )
                        image_processor_class = IMAGE_PROCESSOR_MAPPING.get(configuration, None)
                        image_processor_name = (
                            image_processor_class.__name__ if image_processor_class else "noimage_processor"
                        )
                        if not tokenizer_classes:
                            # We need to test even if there are no tokenizers.
                            tokenizer_classes = [None]
                        else:
                            # Remove the non defined tokenizers
                            # ByT5 and Perceiver are bytes-level and don't define
                            # FastTokenizer, we can just ignore those.
                            tokenizer_classes = [
                                tokenizer_class for tokenizer_class in tokenizer_classes if tokenizer_class is not None
                            ]
                        model_arch_name = model_architecture.__name__
                        # Get the canonical name
                        for _prefix in ["Flax", "TF"]:
                            if model_arch_name.startswith(_prefix):
                                model_arch_name = model_arch_name[len(_prefix) :]
                                break

                        for tokenizer_class in tokenizer_classes:
                            if tokenizer_class is not None:
                                tokenizer_name = tokenizer_class.__name__
                            else:
                                tokenizer_name = "notokenizer"
                        tokenizer_names = []
                        processor_names = []
                        if model_arch_name in tiny_model_summary:
                            tokenizer_names = tiny_model_summary[model_arch_name]["tokenizer_classes"]
                            processor_names = tiny_model_summary[model_arch_name]["processor_classes"]
                        # Adding `None` (if empty) so we can generate tests
                        tokenizer_names = [None] if len(tokenizer_names) == 0 else tokenizer_names
                        processor_names = [None] if len(processor_names) == 0 else processor_names

                            test_name = f"test_{prefix}_{configuration.__name__}_{model_architecture.__name__}_{tokenizer_name}_{feature_extractor_name}_{image_processor_name}"

                            if tokenizer_class is not None or feature_extractor_class is not None:
                        repo_name = f"tiny-random-{model_arch_name}"
                        for tokenizer_name in tokenizer_names:
                            for processor_name in processor_names:
                                if is_test_to_skip(
                                    name, config_class, model_architecture, tokenizer_name, processor_name
                                ):
                                    continue
                                test_name = f"test_{prefix}_{config_class.__name__}_{model_architecture.__name__}_{tokenizer_name}_{processor_name}"
                                dct[test_name] = gen_test(
                                    model_architecture,
                                    checkpoint,
                                    tiny_config,
                                    tokenizer_class,
                                    feature_extractor_class,
                                    image_processor_class,
                                    repo_name, model_architecture, tokenizer_name, processor_name
                                )

    @abstractmethod
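Test generation is now seeded by a JSON summary of the available tiny models rather than by `_CHECKPOINT_FOR_DOC` constants. A small sketch of reading that summary (the URL and the `tokenizer_classes` / `processor_classes` keys are taken verbatim from the hunk above; the example lookup key is illustrative):

```python
import requests

url = (
    "https://huggingface.co/datasets/hf-internal-testing/tiny-random-model-summary"
    "/raw/main/processor_classes.json"
)
tiny_model_summary = requests.get(url).json()

# Each entry maps an architecture name to the tokenizer / processor class names its tiny repo ships with.
entry = tiny_model_summary.get("BertForSequenceClassification", {})
tokenizer_names = entry.get("tokenizer_classes", []) or [None]  # `None` still generates a test, as above
processor_names = entry.get("processor_classes", []) or [None]
print(tokenizer_names, processor_names)
```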
@@ -53,7 +53,7 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
        else []
    )

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
        return conversation_agent, [Conversation("Hi there!")]


@@ -47,8 +47,8 @@ class DepthEstimationPipelineTests(unittest.TestCase, metaclass=PipelineTestCase

    model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        depth_estimator = DepthEstimationPipeline(model=model, feature_extractor=feature_extractor)
+    def get_test_pipeline(self, model, tokenizer, processor):
+        depth_estimator = DepthEstimationPipeline(model=model, feature_extractor=processor)
        return depth_estimator, [
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            "./tests/fixtures/tests_samples/COCO/000000039769.png",


@@ -59,9 +59,9 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli

    @require_pytesseract
    @require_vision
-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        dqa_pipeline = pipeline(
-            "document-question-answering", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+            "document-question-answering", model=model, tokenizer=tokenizer, feature_extractor=processor
        )

        image = INVOICE_URL
@@ -175,7 +175,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
            raise ValueError("We expect lists of floats, nothing else")
        return shape

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        if tokenizer is None:
            self.skipTest("No tokenizer")
            return
@@ -196,9 +196,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
            )

            return
-        feature_extractor = FeatureExtractionPipeline(
-            model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
-        )
+        feature_extractor = FeatureExtractionPipeline(model=model, tokenizer=tokenizer, feature_extractor=processor)
        return feature_extractor, ["This is a test", "This is another test"]

    def run_pipeline_test(self, feature_extractor, examples):


@@ -206,7 +206,7 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
        unmasker.tokenizer.pad_token = None
        self.run_pipeline_test(unmasker, [])

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        if tokenizer is None or tokenizer.mask_token_id is None:
            self.skipTest("The provided tokenizer has no mask token, (probably reformer or wav2vec2)")
@@ -49,8 +49,8 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
    model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
+    def get_test_pipeline(self, model, tokenizer, processor):
+        image_classifier = ImageClassificationPipeline(model=model, feature_extractor=processor, top_k=2)
        examples = [
            Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
            "http://images.cocodataset.org/val2017/000000039769.jpg",


@@ -81,10 +81,8 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
        + (MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING.items() if MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING else [])
    }

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        image_segmenter = ImageSegmentationPipeline(
-            model=model, feature_extractor=feature_extractor, image_processor=image_processor
-        )
+    def get_test_pipeline(self, model, tokenizer, processor):
+        image_segmenter = ImageSegmentationPipeline(model=model, image_processor=processor)
        return image_segmenter, [
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
            "./tests/fixtures/tests_samples/COCO/000000039769.png",


@@ -36,8 +36,8 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
    model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
    tf_model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        pipe = pipeline("image-to-text", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+    def get_test_pipeline(self, model, tokenizer, processor):
+        pipe = pipeline("image-to-text", model=model, tokenizer=tokenizer, feature_extractor=processor)
        examples = [
            Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
            "./tests/fixtures/tests_samples/COCO/000000039769.png",
@@ -51,8 +51,8 @@ else:
class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
-        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
+    def get_test_pipeline(self, model, tokenizer, processor):
+        object_detector = ObjectDetectionPipeline(model=model, feature_extractor=processor)
        return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]

    def run_pipeline_test(self, object_detector, examples):


@@ -31,7 +31,7 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
    tf_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        if isinstance(model.config, LxmertConfig):
            # This is an bimodal model, we need to find a more consistent way
            # to switch on those models.
@@ -34,7 +34,7 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
    model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer)
        return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]


@@ -34,7 +34,7 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
    model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
        return generator, ["Something to write", "Something else"]


@@ -129,7 +129,7 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
        outputs = text_classifier("Birds are a type of animal")
        self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
        return text_classifier, ["HuggingFace is in", "This is another test"]


@@ -143,7 +143,7 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
            ],
        )

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
        return text_generator, ["This is a test", "Another test"]
@@ -37,7 +37,7 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
    model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
        return token_classifier, ["A simple string", "A simple string that is quite a bit longer"]


@@ -34,7 +34,7 @@ class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
    model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        if isinstance(model.config, MBartConfig):
            src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
            translator = TranslationPipeline(model=model, tokenizer=tokenizer, src_lang=src_lang, tgt_lang=tgt_lang)


@@ -35,11 +35,11 @@ from .test_pipelines_common import ANY, PipelineTestCaseMeta
class VideoClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        example_video_filepath = hf_hub_download(
            repo_id="nateraw/video-demo", filename="archery.mp4", repo_type="dataset"
        )
-        video_classifier = VideoClassificationPipeline(model=model, feature_extractor=feature_extractor, top_k=2)
+        video_classifier = VideoClassificationPipeline(model=model, feature_extractor=processor, top_k=2)
        examples = [
            example_video_filepath,
            "https://huggingface.co/datasets/nateraw/video-demo/resolve/main/archery.mp4",
@@ -36,7 +36,7 @@ else:
class VisualQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        vqa_pipeline = pipeline("visual-question-answering", model="hf-internal-testing/tiny-vilt-random-vqa")
        examples = [
            {


@@ -30,7 +30,7 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
    model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
    tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        classifier = ZeroShotClassificationPipeline(
            model=model, tokenizer=tokenizer, candidate_labels=["polics", "health"]
        )


@@ -37,7 +37,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
    # and only CLIP would be there for now.
    # model_mapping = {CLIPConfig: CLIPModel}

-    # def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    # def get_test_pipeline(self, model, tokenizer, processor):
    #     if tokenizer is None:
    #         # Side effect of no Fast Tokenizer class for these model, so skipping
    #         # But the slow tokenizer test should still run as they're quite small
@@ -46,7 +46,7 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase, metaclass=Pipe
    #         # return None, None

    #     image_classifier = ZeroShotImageClassificationPipeline(
-    #         model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
+    #         model=model, tokenizer=tokenizer, feature_extractor=processor
    #     )

    #     # test with a raw waveform


@@ -36,7 +36,7 @@ class ZeroShotObjectDetectionPipelineTests(unittest.TestCase, metaclass=Pipeline

    model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING

-    def get_test_pipeline(self, model, tokenizer, feature_extractor, image_processor):
+    def get_test_pipeline(self, model, tokenizer, processor):
        object_detector = pipeline(
            "zero-shot-object-detection", model="hf-internal-testing/tiny-random-owlvit-object-detection"
        )
@@ -15,6 +15,7 @@

import argparse
import collections.abc
+import copy
import importlib
import inspect
import json
@@ -31,6 +32,7 @@ from huggingface_hub import Repository, create_repo, upload_folder
from transformers import (
    CONFIG_MAPPING,
    FEATURE_EXTRACTOR_MAPPING,
+    IMAGE_PROCESSOR_MAPPING,
    PROCESSOR_MAPPING,
    TOKENIZER_MAPPING,
    AutoTokenizer,
@@ -74,29 +76,36 @@ def get_processor_types_from_config_class(config_class, allowed_mappings=None):

    We use `tuple` here to include (potentially) both slow & fast tokenizers.
    """

    # To make a uniform return type
    def _to_tuple(x):
        if not isinstance(x, collections.abc.Sequence):
            x = (x,)
        else:
            x = tuple(x)
        return x

    if allowed_mappings is None:
        allowed_mappings = ["processor", "tokenizer", "feature_extractor"]
        allowed_mappings = ["processor", "tokenizer", "image_processor", "feature_extractor"]

    processor_types = ()

    # Check first if a model has `ProcessorMixin`. Otherwise, check if it has tokenizers or a feature extractor.
    # Check first if a model has `ProcessorMixin`. Otherwise, check if it has tokenizers, and/or an image processor or
    # a feature extractor
    if config_class in PROCESSOR_MAPPING and "processor" in allowed_mappings:
        processor_types = PROCESSOR_MAPPING[config_class]
    elif config_class in TOKENIZER_MAPPING and "tokenizer" in allowed_mappings:
        processor_types = TOKENIZER_MAPPING[config_class]
    elif config_class in FEATURE_EXTRACTOR_MAPPING and "feature_extractor" in allowed_mappings:
        processor_types = FEATURE_EXTRACTOR_MAPPING[config_class]
        processor_types = _to_tuple(PROCESSOR_MAPPING[config_class])
    else:
        # Some configurations have no processor at all. For example, generic composite models like
        # `EncoderDecoderModel` is used for any (compatible) text models. Also, `DecisionTransformer` doesn't
        # require any processor.
        pass
        if config_class in TOKENIZER_MAPPING and "tokenizer" in allowed_mappings:
            processor_types = TOKENIZER_MAPPING[config_class]

    # make a uniform return type
    if not isinstance(processor_types, collections.abc.Sequence):
        processor_types = (processor_types,)
    else:
        processor_types = tuple(processor_types)
        if config_class in IMAGE_PROCESSOR_MAPPING and "image_processor" in allowed_mappings:
            processor_types += _to_tuple(IMAGE_PROCESSOR_MAPPING[config_class])
        elif config_class in FEATURE_EXTRACTOR_MAPPING and "feature_extractor" in allowed_mappings:
            processor_types += _to_tuple(FEATURE_EXTRACTOR_MAPPING[config_class])

        # Remark: some configurations have no processor at all. For example, generic composite models like
        # `EncoderDecoderModel` is used for any (compatible) text models. Also, `DecisionTransformer` doesn't
        # require any processor.

    # We might get `None` for some tokenizers - remove them here.
    processor_types = tuple(p for p in processor_types if p is not None)
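The restructured lookup no longer stops at the first matching mapping: it now collects tokenizers and an image processor (or feature extractor) for the same config class. A self-contained sketch of that accumulation using the public mappings (simplified; the real function also handles `PROCESSOR_MAPPING` and the `allowed_mappings` filter):

```python
import collections.abc

from transformers import FEATURE_EXTRACTOR_MAPPING, IMAGE_PROCESSOR_MAPPING, TOKENIZER_MAPPING, BertConfig, ViTConfig


def _to_tuple(x):
    # Uniform return type: single classes become 1-tuples, sequences become tuples.
    return tuple(x) if isinstance(x, collections.abc.Sequence) else (x,)


def processor_types_for(config_class):
    processor_types = ()
    if config_class in TOKENIZER_MAPPING:
        processor_types = _to_tuple(TOKENIZER_MAPPING[config_class])
    if config_class in IMAGE_PROCESSOR_MAPPING:
        processor_types += _to_tuple(IMAGE_PROCESSOR_MAPPING[config_class])
    elif config_class in FEATURE_EXTRACTOR_MAPPING:
        processor_types += _to_tuple(FEATURE_EXTRACTOR_MAPPING[config_class])
    # Some tokenizer slots can be None (e.g. no fast tokenizer) - drop them.
    return tuple(p for p in processor_types if p is not None)


print(processor_types_for(BertConfig))  # tokenizer classes only
print(processor_types_for(ViTConfig))   # image processor only
```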
@@ -154,7 +163,7 @@ def get_config_class_from_processor_class(processor_class):
    return new_config_class


-def build_processor(config_class, processor_class):
+def build_processor(config_class, processor_class, allow_no_checkpoint=False):
    """Create a processor for `processor_class`.

    If a processor is not able to be built with the original arguments, this method tries to change the arguments and
@@ -264,6 +273,18 @@ def build_processor(config_class, processor_class):
        if config_class_from_processor_class != config_class:
            processor = build_processor(config_class_from_processor_class, processor_class)

+    # Try to create an image processor or a feature extractor without any checkpoint
+    if (
+        processor is None
+        and allow_no_checkpoint
+        and (issubclass(processor_class, BaseImageProcessor) or issubclass(processor_class, FeatureExtractionMixin))
+    ):
+        try:
+            processor = processor_class()
+        except Exception as e:
+            logger.error(e)
+            pass
+
    # validation
    if processor is not None:
        if not (isinstance(processor, processor_class) or processor_class.__name__.startswith("Auto")):
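The new `allow_no_checkpoint` path falls back to instantiating an image processor or feature extractor with its default constructor when no usable checkpoint exists; tokenizers are excluded because they need vocabulary files. A hedged sketch of that fallback (the class choice is illustrative):

```python
from transformers import Wav2Vec2FeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor


def build_without_checkpoint(processor_class):
    # Only image processors and feature extractors can be built meaningfully from default arguments.
    if issubclass(processor_class, (BaseImageProcessor, FeatureExtractionMixin)):
        try:
            return processor_class()
        except Exception:
            return None
    return None


print(build_without_checkpoint(Wav2Vec2FeatureExtractor))
```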
@@ -458,6 +479,18 @@ def convert_processors(processors, tiny_config, output_folder, result):
            result["warnings"].append(f"Failed to convert feature extractors: {e}")
            feature_extractors = []

+    if hasattr(tiny_config, "max_position_embeddings") and tiny_config.max_position_embeddings > 0:
+        if fast_tokenizer is not None:
+            if fast_tokenizer.__class__.__name__ in ["RobertaTokenizerFast", "XLMRobertaTokenizerFast"]:
+                fast_tokenizer.model_max_length = tiny_config.max_position_embeddings - 2
+            else:
+                fast_tokenizer.model_max_length = tiny_config.max_position_embeddings
+        if slow_tokenizer is not None:
+            if slow_tokenizer.__class__.__name__ in ["RobertaTokenizer", "XLMRobertaTokenizer"]:
+                slow_tokenizer.model_max_length = tiny_config.max_position_embeddings - 2
+            else:
+                slow_tokenizer.model_max_length = tiny_config.max_position_embeddings
+
    processors = [fast_tokenizer, slow_tokenizer] + feature_extractors
    processors = [p for p in processors if p is not None]
    for p in processors:
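The added block keeps tiny tokenizers from emitting sequences longer than the tiny config can embed, with the usual `-2` offset for RoBERTa-style position embeddings. A minimal sketch of the same clamping (the checkpoint and config values are illustrative):

```python
from transformers import AutoTokenizer, BertConfig

tiny_config = BertConfig(max_position_embeddings=512)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

if getattr(tiny_config, "max_position_embeddings", 0) > 0:
    roberta_like = ["RobertaTokenizer", "RobertaTokenizerFast", "XLMRobertaTokenizer", "XLMRobertaTokenizerFast"]
    if tokenizer.__class__.__name__ in roberta_like:
        # RoBERTa-style models reserve two position ids, hence the -2.
        tokenizer.model_max_length = tiny_config.max_position_embeddings - 2
    else:
        tokenizer.model_max_length = tiny_config.max_position_embeddings

print(tokenizer.model_max_length)
```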
@@ -491,6 +524,12 @@ def build_model(model_arch, tiny_config, output_dir):
    if os.path.isdir(processor_output_dir):
        shutil.copytree(processor_output_dir, checkpoint_dir, dirs_exist_ok=True)

+    tiny_config = copy.deepcopy(tiny_config)
+
+    if any([model_arch.__name__.endswith(x) for x in ["ForCausalLM", "LMHeadModel"]]):
+        tiny_config.is_encoder_decoder = False
+        tiny_config.is_decoder = True
+
    model = model_arch(config=tiny_config)
    model.save_pretrained(checkpoint_dir)
    model.from_pretrained(checkpoint_dir)
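`build_model` now deep-copies the shared tiny config before mutating it, so flipping `is_encoder_decoder` / `is_decoder` for `*ForCausalLM` and `*LMHeadModel` heads does not leak into the configs used for other architectures built from the same object. A small sketch of that pattern (model and sizes are illustrative):

```python
import copy

from transformers import BertConfig, BertLMHeadModel

shared_tiny_config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=37)

# Copy before mutating: other architectures built from `shared_tiny_config` keep the original flags.
tiny_config = copy.deepcopy(shared_tiny_config)
if BertLMHeadModel.__name__.endswith(("ForCausalLM", "LMHeadModel")):
    tiny_config.is_encoder_decoder = False
    tiny_config.is_decoder = True

model = BertLMHeadModel(config=tiny_config)
print(shared_tiny_config.is_decoder, tiny_config.is_decoder)  # False True
```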
@@ -819,7 +858,7 @@ def build(config_class, models_to_create, output_dir):

    for processor_class in processor_classes:
        try:
-            processor = build_processor(config_class, processor_class)
+            processor = build_processor(config_class, processor_class, allow_no_checkpoint=True)
            if processor is not None:
                result["processor"][processor_class] = processor
        except Exception as e: