Fix sneaky torch dependency in TF example (#22804)

This commit is contained in:
Matt 2023-04-17 16:11:52 +01:00 committed by GitHub
parent 626c1b8af1
commit 2237127a6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,7 +34,7 @@ from PIL import Image
import transformers
from transformers import (
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
AutoImageProcessor,
DefaultDataCollator,
@ -58,7 +58,7 @@ check_min_version("4.29.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@ -262,11 +262,6 @@ def main():
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")
# region Dataset and labels