mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Allow --arg Value for booleans in HfArgumentParser (#9823)
* Allow --arg Value for booleans in HfArgumentParser * Update last test * Better error message
This commit is contained in:
parent
35d55b7b84
commit
893120facc
@ -15,7 +15,7 @@
|
||||
import dataclasses
|
||||
import json
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, ArgumentTypeError
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
|
||||
@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
|
||||
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
||||
def string_to_bool(v):
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise ArgumentTypeError(
|
||||
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
|
||||
)
|
||||
|
||||
|
||||
class HfArgumentParser(ArgumentParser):
|
||||
"""
|
||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
|
||||
@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
|
||||
if field.default is not dataclasses.MISSING:
|
||||
kwargs["default"] = field.default
|
||||
elif field.type is bool or field.type is Optional[bool]:
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
kwargs["action"] = "store_false" if field.default is True else "store_true"
|
||||
if field.default is True:
|
||||
field_name = f"--no_{field.name}"
|
||||
kwargs["dest"] = field.name
|
||||
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
|
||||
|
||||
# Hack because type=bool in argparse does not behave as we want.
|
||||
kwargs["type"] = string_to_bool
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
# Default value is True if we have no default when of type bool.
|
||||
default = True if field.default is dataclasses.MISSING else field.default
|
||||
# This is the value that will get picked if we don't include --field_name in any way
|
||||
kwargs["default"] = default
|
||||
# This tells argparse we accept 0 or 1 value after --field_name
|
||||
kwargs["nargs"] = "?"
|
||||
# This is the value that will get picked if we do --field_name (without value)
|
||||
kwargs["const"] = True
|
||||
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
|
||||
kwargs["nargs"] = "+"
|
||||
kwargs["type"] = field.type.__args__[0]
|
||||
|
@ -20,6 +20,7 @@ from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
@ -44,6 +45,7 @@ class WithDefaultExample:
|
||||
class WithDefaultBoolExample:
|
||||
foo: bool = False
|
||||
baz: bool = True
|
||||
opt: Optional[bool] = None
|
||||
|
||||
|
||||
class BasicEnum(Enum):
|
||||
@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument("--bar", type=float, required=True)
|
||||
expected.add_argument("--baz", type=str, required=True)
|
||||
expected.add_argument("--flag", action="store_true")
|
||||
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_default(self):
|
||||
@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
parser = HfArgumentParser(WithDefaultBoolExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", action="store_true")
|
||||
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
|
||||
expected.add_argument("--no_baz", action="store_false", dest="baz")
|
||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True))
|
||||
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no_baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False))
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
|
||||
|
||||
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
parser = HfArgumentParser(EnumExample)
|
||||
|
Loading…
Reference in New Issue
Block a user