HfArgumentParser: allow for hyhenated field names in long-options (#33990)

Allow for hyphenated field names in long-options

argparse converts hyphens into underscores before assignment (e.g., an
option passed as `--long-option` will be stored under `long_option`), So
there is no need to pass options as literal attributes, as in
`--long_option` (with an underscore instead of a hyphen). This commit
ensures that this behavior is respected by `parse_args_into_dataclasses`
as well.

Issue: #33933

Co-authored-by: Daniel Marti <mrtidm@amazon.com>
This commit is contained in:
Dani Martí 2024-10-10 02:58:26 -07:00 committed by GitHub
parent adea67541a
commit a84c413773
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 14 deletions

View File

@ -138,7 +138,14 @@ class HfArgumentParser(ArgumentParser):
@staticmethod
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
field_name = f"--{field.name}"
# Long-option strings are conventionlly separated by hyphens rather
# than underscores, e.g., "--long-format" rather than "--long_format".
# Argparse converts hyphens to underscores so that the destination
# string is a valid attribute name. Hf_argparser should do the same.
long_options = [f"--{field.name}"]
if "_" in field.name:
long_options.append(f"--{field.name.replace('_', '-')}")
kwargs = field.metadata.copy()
# field.metadata is not used at all by Data Classes,
# it is provided as a third-party extension mechanism.
@ -198,11 +205,11 @@ class HfArgumentParser(ArgumentParser):
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
# Default value is False if we have no default when of type bool.
default = False if field.default is dataclasses.MISSING else field.default
# This is the value that will get picked if we don't include --field_name in any way
# This is the value that will get picked if we don't include --{field.name} in any way
kwargs["default"] = default
# This tells argparse we accept 0 or 1 value after --field_name
# This tells argparse we accept 0 or 1 value after --{field.name}
kwargs["nargs"] = "?"
# This is the value that will get picked if we do --field_name (without value)
# This is the value that will get picked if we do --{field.name} (without value)
kwargs["const"] = True
elif isclass(origin_type) and issubclass(origin_type, list):
kwargs["type"] = field.type.__args__[0]
@ -219,7 +226,7 @@ class HfArgumentParser(ArgumentParser):
kwargs["default"] = field.default_factory()
else:
kwargs["required"] = True
parser.add_argument(field_name, *aliases, **kwargs)
parser.add_argument(*long_options, *aliases, **kwargs)
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
# Order is important for arguments with the same destination!
@ -227,7 +234,13 @@ class HfArgumentParser(ArgumentParser):
# here and we do not need those changes/additional keys.
if field.default is True and (field.type is bool or field.type == Optional[bool]):
bool_kwargs["default"] = False
parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)
parser.add_argument(
f"--no_{field.name}",
f"--no-{field.name.replace('_', '-')}",
action="store_false",
dest=field.name,
**bool_kwargs,
)
def _add_dataclass_arguments(self, dtype: DataClassType):
if hasattr(dtype, "_argument_group_name"):

View File

@ -189,7 +189,7 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
# A boolean no_* argument always has to come after its "default: True" regular counter-part
# and its default must be set to False
expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
expected.add_argument("--opt", type=string_to_bool, default=None)
dataclass_types = [WithDefaultBoolExample]
@ -206,6 +206,9 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args(["--foo", "--no_baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--no-baz"])
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
args = parser.parse_args(["--foo", "--baz"])
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
@ -271,10 +274,10 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(ListExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
self.argparsersEqual(parser, expected)
@ -287,6 +290,9 @@ class HfArgumentParserTest(unittest.TestCase):
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
def test_with_optional(self):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default=None, type=int)
@ -314,10 +320,11 @@ class HfArgumentParserTest(unittest.TestCase):
parser = HfArgumentParser(RequiredExample)
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", "--required-str", type=str, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
@ -331,13 +338,14 @@ class HfArgumentParserTest(unittest.TestCase):
expected.add_argument("--foo", type=int, required=True)
expected.add_argument(
"--required_enum",
"--required-enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
self.argparsersEqual(parser, expected)
def test_parse_dict(self):