mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
HfArgumentParser: allow for hyphenated field names in long-options (#33990)
Allow for hyphenated field names in long-options. Argparse converts hyphens into underscores before assignment (e.g., an option passed as `--long-option` will be stored under `long_option`), so there is no need to pass options as literal attributes, as in `--long_option` (with an underscore instead of a hyphen). This commit ensures that this behavior is respected by `parse_args_into_dataclasses` as well. Issue: #33933 Co-authored-by: Daniel Marti <mrtidm@amazon.com>
This commit is contained in:
parent
adea67541a
commit
a84c413773
@ -138,7 +138,14 @@ class HfArgumentParser(ArgumentParser):
|
||||
|
||||
@staticmethod
|
||||
def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
|
||||
field_name = f"--{field.name}"
|
||||
# Long-option strings are conventionally separated by hyphens rather
|
||||
# than underscores, e.g., "--long-format" rather than "--long_format".
|
||||
# Argparse converts hyphens to underscores so that the destination
|
||||
# string is a valid attribute name. Hf_argparser should do the same.
|
||||
long_options = [f"--{field.name}"]
|
||||
if "_" in field.name:
|
||||
long_options.append(f"--{field.name.replace('_', '-')}")
|
||||
|
||||
kwargs = field.metadata.copy()
|
||||
# field.metadata is not used at all by Data Classes,
|
||||
# it is provided as a third-party extension mechanism.
|
||||
@ -198,11 +205,11 @@ class HfArgumentParser(ArgumentParser):
|
||||
if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
|
||||
# Default value is False if we have no default when of type bool.
|
||||
default = False if field.default is dataclasses.MISSING else field.default
|
||||
# This is the value that will get picked if we don't include --field_name in any way
|
||||
# This is the value that will get picked if we don't include --{field.name} in any way
|
||||
kwargs["default"] = default
|
||||
# This tells argparse we accept 0 or 1 value after --field_name
|
||||
# This tells argparse we accept 0 or 1 value after --{field.name}
|
||||
kwargs["nargs"] = "?"
|
||||
# This is the value that will get picked if we do --field_name (without value)
|
||||
# This is the value that will get picked if we do --{field.name} (without value)
|
||||
kwargs["const"] = True
|
||||
elif isclass(origin_type) and issubclass(origin_type, list):
|
||||
kwargs["type"] = field.type.__args__[0]
|
||||
@ -219,7 +226,7 @@ class HfArgumentParser(ArgumentParser):
|
||||
kwargs["default"] = field.default_factory()
|
||||
else:
|
||||
kwargs["required"] = True
|
||||
parser.add_argument(field_name, *aliases, **kwargs)
|
||||
parser.add_argument(*long_options, *aliases, **kwargs)
|
||||
|
||||
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
||||
# Order is important for arguments with the same destination!
|
||||
@ -227,7 +234,13 @@ class HfArgumentParser(ArgumentParser):
|
||||
# here and we do not need those changes/additional keys.
|
||||
if field.default is True and (field.type is bool or field.type == Optional[bool]):
|
||||
bool_kwargs["default"] = False
|
||||
parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs)
|
||||
parser.add_argument(
|
||||
f"--no_{field.name}",
|
||||
f"--no-{field.name.replace('_', '-')}",
|
||||
action="store_false",
|
||||
dest=field.name,
|
||||
**bool_kwargs,
|
||||
)
|
||||
|
||||
def _add_dataclass_arguments(self, dtype: DataClassType):
|
||||
if hasattr(dtype, "_argument_group_name"):
|
||||
|
@ -189,7 +189,7 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--baz", type=string_to_bool, default=True, const=True, nargs="?")
|
||||
# A boolean no_* argument always has to come after its "default: True" regular counter-part
|
||||
# and its default must be set to False
|
||||
expected.add_argument("--no_baz", action="store_false", default=False, dest="baz")
|
||||
expected.add_argument("--no_baz", "--no-baz", action="store_false", default=False, dest="baz")
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
|
||||
dataclass_types = [WithDefaultBoolExample]
|
||||
@ -206,6 +206,9 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args(["--foo", "--no_baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--no-baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=False, opt=None))
|
||||
|
||||
args = parser.parse_args(["--foo", "--baz"])
|
||||
self.assertEqual(args, Namespace(foo=True, baz=True, opt=None))
|
||||
|
||||
@ -271,10 +274,10 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo_int", nargs="+", default=[], type=int)
|
||||
expected.add_argument("--bar_int", nargs="+", default=[1, 2, 3], type=int)
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||
expected.add_argument("--foo_int", "--foo-int", nargs="+", default=[], type=int)
|
||||
expected.add_argument("--bar_int", "--bar-int", nargs="+", default=[1, 2, 3], type=int)
|
||||
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_float", "--foo-float", nargs="+", default=[0.1, 0.2, 0.3], type=float)
|
||||
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
@ -287,6 +290,9 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
args = parser.parse_args("--foo_int 1 --bar_int 2 3 --foo_str a b c --foo_float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
args = parser.parse_args("--foo-int 1 --bar-int 2 3 --foo-str a b c --foo-float 0.1 0.7".split())
|
||||
self.assertEqual(args, Namespace(foo_int=[1], bar_int=[2, 3], foo_str=["a", "b", "c"], foo_float=[0.1, 0.7]))
|
||||
|
||||
def test_with_optional(self):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default=None, type=int)
|
||||
@ -314,10 +320,11 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
parser = HfArgumentParser(RequiredExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", type=str, required=True)
|
||||
expected.add_argument("--required_list", "--required-list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", "--required-str", type=str, required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
"--required-enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
@ -331,13 +338,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
"--required-enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
expected.add_argument("--foo_str", "--foo-str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_parse_dict(self):
|
||||
|
Loading…
Reference in New Issue
Block a user