fix dataset run_object_detection

This commit is contained in:
Quentin Lhoest 2025-06-26 16:56:12 +02:00 committed by ydshieh
parent 3457e8e73e
commit c1223639b9
4 changed files with 4 additions and 4 deletions

View File

@ -399,7 +399,7 @@ def main():
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
categories = dataset["train"].features["objects"]["category"].feature.names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}

View File

@ -460,7 +460,7 @@ def main():
dataset["validation"] = split["test"]
# Get dataset categories and prepare mappings for label_name <-> label_id
categories = dataset["train"].features["objects"].feature["category"].names
categories = dataset["train"].features["objects"]["category"].feature.names
id2label = dict(enumerate(categories))
label2id = {v: k for k, v in id2label.items()}

View File

@ -341,7 +341,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
testargs = f"""
{self.examples_dir}/pytorch/object-detection/run_object_detection_no_trainer.py
--model_name_or_path qubvel-hf/detr-resnet-50-finetuned-10k-cppe5
--dataset_name qubvel-hf/cppe-5-sample
--dataset_name hf-internal-testing/cppe-5-sample
--output_dir {tmp_dir}
--max_train_steps=10
--num_warmup_steps=2

View File

@ -620,7 +620,7 @@ class ExamplesTests(TestCasePlus):
run_object_detection.py
--model_name_or_path qubvel-hf/detr-resnet-50-finetuned-10k-cppe5
--output_dir {tmp_dir}
--dataset_name qubvel-hf/cppe-5-sample
--dataset_name hf-internal-testing/cppe-5-sample
--do_train
--do_eval
--remove_unused_columns False