Allow --arg Value for booleans in HfArgumentParser (#9823)

* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
This commit is contained in:
Sylvain Gugger 2021-01-27 09:31:42 -05:00 committed by GitHub
parent 35d55b7b84
commit 893120facc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 9 deletions

View File

@ -15,7 +15,7 @@
import dataclasses import dataclasses
import json import json
import sys import sys
from argparse import ArgumentParser from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any) DataClassType = NewType("DataClassType", Any)
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
    """
    Convert a command-line value into a boolean, for use as an argparse ``type=`` callable.

    Args:
        v: The value to convert. A ``bool`` is returned unchanged. A string is compared
            case-insensitively, ignoring leading and trailing whitespace.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0.

    Raises:
        ArgumentTypeError: If ``v`` is not a bool and is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v
    # Only strings can match a spelling. Any other type falls through to the ArgumentTypeError below
    # instead of failing with AttributeError on `.lower()`, which argparse would not report as a usage error.
    if isinstance(v, str):
        normalized = v.strip().lower()
        if normalized in ("yes", "true", "t", "y", "1"):
            return True
        if normalized in ("no", "false", "f", "n", "0"):
            return False
    raise ArgumentTypeError(
        f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
    )
class HfArgumentParser(ArgumentParser): class HfArgumentParser(ArgumentParser):
""" """
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
if field.default is not dataclasses.MISSING: if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]: elif field.type is bool or field.type is Optional[bool]:
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True: if field.default is True:
field_name = f"--no_{field.name}" self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
kwargs["dest"] = field.name
# Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is True if we have no default when of type bool.
default = True if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List): elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+" kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0] kwargs["type"] = field.type.__args__[0]

View File

@ -20,6 +20,7 @@ from enum import Enum
from typing import List, Optional from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
def list_field(default=None, metadata=None): def list_field(default=None, metadata=None):
@ -44,6 +45,7 @@ class WithDefaultExample:
class WithDefaultBoolExample: class WithDefaultBoolExample:
foo: bool = False foo: bool = False
baz: bool = True baz: bool = True
opt: Optional[bool] = None
class BasicEnum(Enum): class BasicEnum(Enum):
@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True) expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True) expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True) expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", action="store_true") expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
def test_with_default(self): def test_with_default(self):
@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(WithDefaultBoolExample) parser = HfArgumentParser(WithDefaultBoolExample)
expected = argparse.ArgumentParser() expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true") expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--no_baz", action="store_false", dest="baz") expected.add_argument("--no_baz", action="store_false", dest="baz")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected) self.argparsersEqual(parser, expected)
args = parser.parse_args([]) args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True)) self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
args = parser.parse_args(["--foo", "--no_baz"]) args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False)) self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self): def test_with_enum(self):
parser = HfArgumentParser(EnumExample) parser = HfArgumentParser(EnumExample)