Allow --arg Value for booleans in HfArgumentParser (#9823)

* Allow --arg Value for booleans in HfArgumentParser

* Update last test

* Better error message
This commit is contained in:
Sylvain Gugger 2021-01-27 09:31:42 -05:00 committed by GitHub
parent 35d55b7b84
commit 893120facc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 45 additions and 9 deletions

View File

@ -15,7 +15,7 @@
import dataclasses
import json
import sys
from argparse import ArgumentParser
from argparse import ArgumentParser, ArgumentTypeError
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, List, NewType, Optional, Tuple, Union
@ -25,6 +25,20 @@ DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
# From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
def string_to_bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise ArgumentTypeError(
f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
)
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
@ -85,11 +99,20 @@ class HfArgumentParser(ArgumentParser):
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
elif field.type is bool or field.type is Optional[bool]:
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
kwargs["action"] = "store_false" if field.default is True else "store_true"
if field.default is True:
field_name = f"--no_{field.name}"
kwargs["dest"] = field.name
self.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **kwargs)
# Hack because type=bool in argparse does not behave as we want.
kwargs["type"] = string_to_bool
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is True if we have no default when of type bool.
default = True if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
kwargs["const"] = True
elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
kwargs["nargs"] = "+"
kwargs["type"] = field.type.__args__[0]

View File

@ -20,6 +20,7 @@ from enum import Enum
from typing import List, Optional
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
def list_field(default=None, metadata=None):
@ -44,6 +45,7 @@ class WithDefaultExample:
class WithDefaultBoolExample:
foo: bool = False
baz: bool = True
opt: Optional[bool] = None
class BasicEnum(Enum):
@ -91,7 +93,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--bar", type=float, required=True)
expected.add_argument("--baz", type=str, required=True)
expected.add_argument("--flag", action="store_true")
expected.add_argument("--flag", type=string_to_bool, default=True, const=True, nargs="?")
self.argparsersEqual(parser, expected)
def test_with_default(self):
@ -106,15 +108,26 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(WithDefaultBoolExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", action="store_true")
expected.add_argument("--foo", type=string_to_bool, default=False, const=True, nargs="?")
expected.add_argument("--no_baz", action="store_false", dest="baz")
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
expected.add_argument("--opt", type=string_to_bool, default=None)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args, Namespace(foo=False, baz=True))
self.assertEqual(args, Namespace(foo=False, baz=True, opt=None))
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False))
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
args = parser.parse_args(["--foo", "True", "--baz", "True", "--opt", "True"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=True))
args = parser.parse_args(["--foo", "False", "--baz", "False", "--opt", "False"])
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self):
parser = HfArgumentParser(EnumExample)