mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix #3305: run_ner only possible on ModelForTokenClassification models
This commit is contained in:
parent
0c44b11917
commit
656e1386a2
@ -31,7 +31,6 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
@ -39,7 +38,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.modeling_auto import MODEL_MAPPING
|
||||
from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
|
||||
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
|
||||
|
||||
|
||||
@ -51,8 +50,9 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
|
||||
|
||||
TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
|
||||
|
||||
@ -384,7 +384,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
|
Loading…
Reference in New Issue
Block a user