Enhance HfArgumentParser functionality and ease of use (#20323)

* Enhance HfArgumentParser

* Fix type hints for older python versions

* Fix and add tests (+formatting)

* Add changes

* doc-builder formatting

* Remove unused import "Call"
This commit is contained in:
Konstantin Dobler 2022-11-21 18:33:37 +01:00 committed by GitHub
parent 96783e53b4
commit 1e3f17b5ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 204 additions and 23 deletions

View File

@ -20,11 +20,19 @@ from copy import copy
from enum import Enum
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Tuple, Union, get_type_hints
import yaml
try:
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
from typing import Literal
except ImportError:
# For Python 3.7
from typing_extensions import Literal
DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@ -43,6 +51,68 @@ def string_to_bool(v):
)
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
"""
Creates a mapping function from each choices string representation to the actual value. Used to support multiple
value types for a single argument.
Args:
choices (list): List of choices.
Returns:
Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
"""
str_to_choice = {str(choice): choice for choice in choices}
return lambda arg: str_to_choice.get(arg, arg)
def HfArg(
*,
aliases: Union[str, List[str]] = None,
help: str = None,
default: Any = dataclasses.MISSING,
default_factory: Callable[[], Any] = dataclasses.MISSING,
metadata: dict = None,
**kwargs,
) -> dataclasses.Field:
"""Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
Example comparing the use of `HfArg` and `dataclasses.field`:
```
@dataclass
class Args:
regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
```
Args:
aliases (Union[str, List[str]], optional):
Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
Defaults to None.
help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
default (Any, optional):
Default value for the argument. If not default or default_factory is specified, the argument is required.
Defaults to dataclasses.MISSING.
default_factory (Callable[[], Any], optional):
The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
Defaults to dataclasses.MISSING.
metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
Returns:
Field: A `dataclasses.Field` with the desired properties.
"""
if metadata is None:
# Important, don't use as default param in function signature because dict is mutable and shared across function calls
metadata = {}
if aliases is not None:
metadata["aliases"] = aliases
if help is not None:
metadata["help"] = help
return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
class HfArgumentParser(ArgumentParser):
"""
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser):
"`typing.get_type_hints` method by default"
)
aliases = kwargs.pop("aliases", [])
if isinstance(aliases, str):
aliases = [aliases]
origin_type = getattr(field.type, "__origin__", field.type)
if origin_type is Union:
if str not in field.type.__args__ and (
@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser):
# A variable to store kwargs for a boolean field, if needed
# so that we can init a `no_*` complement argument (see below)
bool_kwargs = {}
if isinstance(field.type, type) and issubclass(field.type, Enum):
kwargs["choices"] = [x.value for x in field.type]
kwargs["type"] = type(kwargs["choices"][0])
if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
if origin_type is Literal:
kwargs["choices"] = field.type.__args__
else:
kwargs["choices"] = [x.value for x in field.type]
kwargs["type"] = make_choice_type_function(kwargs["choices"])
if field.default is not dataclasses.MISSING:
kwargs["default"] = field.default
else:
@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser):
kwargs["default"] = field.default_factory()
else:
kwargs["required"] = True
parser.add_argument(field_name, **kwargs)
parser.add_argument(field_name, *aliases, **kwargs)
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
# Order is important for arguments with the same destination!
@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser):
self._parse_dataclass_field(parser, field)
def parse_args_into_dataclasses(
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
self,
args=None,
return_remaining_strings=False,
look_for_args_file=True,
args_filename=None,
args_file_flag=None,
) -> Tuple[DataClass, ...]:
"""
Parse command-line args into instances of the specified dataclass types.
@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser):
process, and will append its potential content to the command line args.
args_filename:
If not None, will uses this file instead of the ".args" file specified in the previous argument.
args_file_flag:
If not None, will look for a file in the command-line args specified with this flag. The flag can be
specified multiple times and precedence is determined by the order (last one wins).
Returns:
Tuple consisting of:
@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser):
after initialization.
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
"""
if args_filename or (look_for_args_file and len(sys.argv)):
if args_filename:
args_file = Path(args_filename)
else:
args_file = Path(sys.argv[0]).with_suffix(".args")
if args_file.exists():
fargs = args_file.read_text().split()
args = fargs + args if args is not None else fargs + sys.argv[1:]
# in case of duplicate arguments the first one has precedence
# so we append rather than prepend.
if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
args_files = []
if args_filename:
args_files.append(Path(args_filename))
elif look_for_args_file and len(sys.argv):
args_files.append(Path(sys.argv[0]).with_suffix(".args"))
# args files specified via command line flag should overwrite default args files so we add them last
if args_file_flag:
# Create special parser just to extract the args_file_flag values
args_file_parser = ArgumentParser()
args_file_parser.add_argument(args_file_flag, type=str, action="append")
# Use only remaining args for further parsing (remove the args_file_flag)
cfg, args = args_file_parser.parse_known_args(args=args)
cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
if cmd_args_file_paths:
args_files.extend([Path(p) for p in cmd_args_file_paths])
file_args = []
for args_file in args_files:
if args_file.exists():
file_args += args_file.read_text().split()
# in case of duplicate arguments the last one has precedence
# args specified via the command line should overwrite args from files, so we add them last
args = file_args + args if args is not None else file_args + sys.argv[1:]
namespace, remaining_args = self.parse_known_args(args=args)
outputs = []
for dtype in self.dataclass_types:

View File

@ -25,7 +25,15 @@ from typing import List, Optional
import yaml
from transformers import HfArgumentParser, TrainingArguments
from transformers.hf_argparser import string_to_bool
from transformers.hf_argparser import make_choice_type_function, string_to_bool
try:
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
from typing import Literal
except ImportError:
# For Python 3.7
from typing_extensions import Literal
def list_field(default=None, metadata=None):
@ -58,6 +66,12 @@ class BasicEnum(Enum):
toto = "toto"
class MixedTypeEnum(Enum):
titi = "titi"
toto = "toto"
fourtytwo = 42
@dataclass
class EnumExample:
foo: BasicEnum = "toto"
@ -66,6 +80,14 @@ class EnumExample:
self.foo = BasicEnum(self.foo)
@dataclass
class MixedTypeEnumExample:
foo: MixedTypeEnum = "toto"
def __post_init__(self):
self.foo = MixedTypeEnum(self.foo)
@dataclass
class OptionalExample:
foo: Optional[int] = None
@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
for x, y in zip(a._actions, b._actions):
xx = {k: v for k, v in vars(x).items() if k != "container"}
yy = {k: v for k, v in vars(y).items() if k != "container"}
# Choices with mixed type have custom function as "type"
# So we need to compare results directly for equality
if xx.get("choices", None) and yy.get("choices", None):
for expected_choice in yy["choices"] + xx["choices"]:
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
del xx["type"], yy["type"]
self.assertEqual(xx, yy)
def test_basic(self):
@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
def test_with_enum(self):
parser = HfArgumentParser(EnumExample)
parser = HfArgumentParser(MixedTypeEnumExample)
expected = argparse.ArgumentParser()
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
expected.add_argument(
"--foo",
default="toto",
choices=["titi", "toto", 42],
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
enum_ex = parser.parse_args_into_dataclasses([])[0]
self.assertEqual(enum_ex.foo, BasicEnum.toto)
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
self.assertEqual(enum_ex.foo, BasicEnum.titi)
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
def test_with_literal(self):
@dataclass
class LiteralExample:
foo: Literal["titi", "toto", 42] = "toto"
parser = HfArgumentParser(LiteralExample)
expected = argparse.ArgumentParser()
expected.add_argument(
"--foo",
default="toto",
choices=("titi", "toto", 42),
type=make_choice_type_function(["titi", "toto", 42]),
)
self.argparsersEqual(parser, expected)
args = parser.parse_args([])
self.assertEqual(args.foo, "toto")
args = parser.parse_args(["--foo", "titi"])
self.assertEqual(args.foo, "titi")
args = parser.parse_args(["--foo", "42"])
self.assertEqual(args.foo, 42)
def test_with_list(self):
parser = HfArgumentParser(ListExample)
@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser()
expected.add_argument("--required_list", nargs="+", type=int, required=True)
expected.add_argument("--required_str", type=str, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
self.argparsersEqual(parser, expected)
def test_with_string_literal_annotation(self):
@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
expected = argparse.ArgumentParser()
expected.add_argument("--foo", type=int, required=True)
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
expected.add_argument(
"--required_enum",
type=make_choice_type_function(["titi", "toto"]),
choices=["titi", "toto"],
required=True,
)
expected.add_argument("--opt", type=string_to_bool, default=None)
expected.add_argument("--baz", default="toto", type=str, help="help message")
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)