mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enhance HfArgumentParser functionality and ease of use (#20323)
* Enhance HfArgumentParser * Fix type hints for older python versions * Fix and add tests (+formatting) * Add changes * doc-builder formatting * Remove unused import "Call"
This commit is contained in:
parent
96783e53b4
commit
1e3f17b5ab
@ -20,11 +20,19 @@ from copy import copy
|
||||
from enum import Enum
|
||||
from inspect import isclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
|
||||
from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Tuple, Union, get_type_hints
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
try:
|
||||
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
# For Python 3.7
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
DataClass = NewType("DataClass", Any)
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
@ -43,6 +51,68 @@ def string_to_bool(v):
|
||||
)
|
||||
|
||||
|
||||
def make_choice_type_function(choices: list) -> Callable[[str], Any]:
|
||||
"""
|
||||
Creates a mapping function from each choices string representation to the actual value. Used to support multiple
|
||||
value types for a single argument.
|
||||
|
||||
Args:
|
||||
choices (list): List of choices.
|
||||
|
||||
Returns:
|
||||
Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
|
||||
"""
|
||||
str_to_choice = {str(choice): choice for choice in choices}
|
||||
return lambda arg: str_to_choice.get(arg, arg)
|
||||
|
||||
|
||||
def HfArg(
|
||||
*,
|
||||
aliases: Union[str, List[str]] = None,
|
||||
help: str = None,
|
||||
default: Any = dataclasses.MISSING,
|
||||
default_factory: Callable[[], Any] = dataclasses.MISSING,
|
||||
metadata: dict = None,
|
||||
**kwargs,
|
||||
) -> dataclasses.Field:
|
||||
"""Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
|
||||
|
||||
Example comparing the use of `HfArg` and `dataclasses.field`:
|
||||
```
|
||||
@dataclass
|
||||
class Args:
|
||||
regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
|
||||
hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
|
||||
```
|
||||
|
||||
Args:
|
||||
aliases (Union[str, List[str]], optional):
|
||||
Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
|
||||
Defaults to None.
|
||||
help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
|
||||
default (Any, optional):
|
||||
Default value for the argument. If not default or default_factory is specified, the argument is required.
|
||||
Defaults to dataclasses.MISSING.
|
||||
default_factory (Callable[[], Any], optional):
|
||||
The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
|
||||
default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
|
||||
Defaults to dataclasses.MISSING.
|
||||
metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Field: A `dataclasses.Field` with the desired properties.
|
||||
"""
|
||||
if metadata is None:
|
||||
# Important, don't use as default param in function signature because dict is mutable and shared across function calls
|
||||
metadata = {}
|
||||
if aliases is not None:
|
||||
metadata["aliases"] = aliases
|
||||
if help is not None:
|
||||
metadata["help"] = help
|
||||
|
||||
return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
|
||||
|
||||
|
||||
class HfArgumentParser(ArgumentParser):
|
||||
"""
|
||||
This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
|
||||
@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser):
|
||||
"`typing.get_type_hints` method by default"
|
||||
)
|
||||
|
||||
aliases = kwargs.pop("aliases", [])
|
||||
if isinstance(aliases, str):
|
||||
aliases = [aliases]
|
||||
|
||||
origin_type = getattr(field.type, "__origin__", field.type)
|
||||
if origin_type is Union:
|
||||
if str not in field.type.__args__ and (
|
||||
@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser):
|
||||
# A variable to store kwargs for a boolean field, if needed
|
||||
# so that we can init a `no_*` complement argument (see below)
|
||||
bool_kwargs = {}
|
||||
if isinstance(field.type, type) and issubclass(field.type, Enum):
|
||||
kwargs["choices"] = [x.value for x in field.type]
|
||||
kwargs["type"] = type(kwargs["choices"][0])
|
||||
if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
|
||||
if origin_type is Literal:
|
||||
kwargs["choices"] = field.type.__args__
|
||||
else:
|
||||
kwargs["choices"] = [x.value for x in field.type]
|
||||
|
||||
kwargs["type"] = make_choice_type_function(kwargs["choices"])
|
||||
|
||||
if field.default is not dataclasses.MISSING:
|
||||
kwargs["default"] = field.default
|
||||
else:
|
||||
@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser):
|
||||
kwargs["default"] = field.default_factory()
|
||||
else:
|
||||
kwargs["required"] = True
|
||||
parser.add_argument(field_name, **kwargs)
|
||||
parser.add_argument(field_name, *aliases, **kwargs)
|
||||
|
||||
# Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
|
||||
# Order is important for arguments with the same destination!
|
||||
@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser):
|
||||
self._parse_dataclass_field(parser, field)
|
||||
|
||||
def parse_args_into_dataclasses(
|
||||
self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None
|
||||
self,
|
||||
args=None,
|
||||
return_remaining_strings=False,
|
||||
look_for_args_file=True,
|
||||
args_filename=None,
|
||||
args_file_flag=None,
|
||||
) -> Tuple[DataClass, ...]:
|
||||
"""
|
||||
Parse command-line args into instances of the specified dataclass types.
|
||||
@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser):
|
||||
process, and will append its potential content to the command line args.
|
||||
args_filename:
|
||||
If not None, will uses this file instead of the ".args" file specified in the previous argument.
|
||||
args_file_flag:
|
||||
If not None, will look for a file in the command-line args specified with this flag. The flag can be
|
||||
specified multiple times and precedence is determined by the order (last one wins).
|
||||
|
||||
Returns:
|
||||
Tuple consisting of:
|
||||
@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser):
|
||||
after initialization.
|
||||
- The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
|
||||
"""
|
||||
if args_filename or (look_for_args_file and len(sys.argv)):
|
||||
if args_filename:
|
||||
args_file = Path(args_filename)
|
||||
else:
|
||||
args_file = Path(sys.argv[0]).with_suffix(".args")
|
||||
|
||||
if args_file.exists():
|
||||
fargs = args_file.read_text().split()
|
||||
args = fargs + args if args is not None else fargs + sys.argv[1:]
|
||||
# in case of duplicate arguments the first one has precedence
|
||||
# so we append rather than prepend.
|
||||
if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
|
||||
args_files = []
|
||||
|
||||
if args_filename:
|
||||
args_files.append(Path(args_filename))
|
||||
elif look_for_args_file and len(sys.argv):
|
||||
args_files.append(Path(sys.argv[0]).with_suffix(".args"))
|
||||
|
||||
# args files specified via command line flag should overwrite default args files so we add them last
|
||||
if args_file_flag:
|
||||
# Create special parser just to extract the args_file_flag values
|
||||
args_file_parser = ArgumentParser()
|
||||
args_file_parser.add_argument(args_file_flag, type=str, action="append")
|
||||
|
||||
# Use only remaining args for further parsing (remove the args_file_flag)
|
||||
cfg, args = args_file_parser.parse_known_args(args=args)
|
||||
cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
|
||||
|
||||
if cmd_args_file_paths:
|
||||
args_files.extend([Path(p) for p in cmd_args_file_paths])
|
||||
|
||||
file_args = []
|
||||
for args_file in args_files:
|
||||
if args_file.exists():
|
||||
file_args += args_file.read_text().split()
|
||||
|
||||
# in case of duplicate arguments the last one has precedence
|
||||
# args specified via the command line should overwrite args from files, so we add them last
|
||||
args = file_args + args if args is not None else file_args + sys.argv[1:]
|
||||
namespace, remaining_args = self.parse_known_args(args=args)
|
||||
outputs = []
|
||||
for dtype in self.dataclass_types:
|
||||
|
@ -25,7 +25,15 @@ from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
from transformers import HfArgumentParser, TrainingArguments
|
||||
from transformers.hf_argparser import string_to_bool
|
||||
from transformers.hf_argparser import make_choice_type_function, string_to_bool
|
||||
|
||||
|
||||
try:
|
||||
# For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/
|
||||
from typing import Literal
|
||||
except ImportError:
|
||||
# For Python 3.7
|
||||
from typing_extensions import Literal
|
||||
|
||||
|
||||
def list_field(default=None, metadata=None):
|
||||
@ -58,6 +66,12 @@ class BasicEnum(Enum):
|
||||
toto = "toto"
|
||||
|
||||
|
||||
class MixedTypeEnum(Enum):
|
||||
titi = "titi"
|
||||
toto = "toto"
|
||||
fourtytwo = 42
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumExample:
|
||||
foo: BasicEnum = "toto"
|
||||
@ -66,6 +80,14 @@ class EnumExample:
|
||||
self.foo = BasicEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MixedTypeEnumExample:
|
||||
foo: MixedTypeEnum = "toto"
|
||||
|
||||
def __post_init__(self):
|
||||
self.foo = MixedTypeEnum(self.foo)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionalExample:
|
||||
foo: Optional[int] = None
|
||||
@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
for x, y in zip(a._actions, b._actions):
|
||||
xx = {k: v for k, v in vars(x).items() if k != "container"}
|
||||
yy = {k: v for k, v in vars(y).items() if k != "container"}
|
||||
|
||||
# Choices with mixed type have custom function as "type"
|
||||
# So we need to compare results directly for equality
|
||||
if xx.get("choices", None) and yy.get("choices", None):
|
||||
for expected_choice in yy["choices"] + xx["choices"]:
|
||||
self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice))
|
||||
del xx["type"], yy["type"]
|
||||
|
||||
self.assertEqual(xx, yy)
|
||||
|
||||
def test_basic(self):
|
||||
@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
self.assertEqual(args, Namespace(foo=False, baz=False, opt=False))
|
||||
|
||||
def test_with_enum(self):
|
||||
parser = HfArgumentParser(EnumExample)
|
||||
parser = HfArgumentParser(MixedTypeEnumExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str)
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=["titi", "toto", 42],
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
enum_ex = parser.parse_args_into_dataclasses([])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.toto)
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.toto)
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0]
|
||||
self.assertEqual(enum_ex.foo, BasicEnum.titi)
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.titi)
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0]
|
||||
self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo)
|
||||
|
||||
def test_with_literal(self):
|
||||
@dataclass
|
||||
class LiteralExample:
|
||||
foo: Literal["titi", "toto", 42] = "toto"
|
||||
|
||||
parser = HfArgumentParser(LiteralExample)
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument(
|
||||
"--foo",
|
||||
default="toto",
|
||||
choices=("titi", "toto", 42),
|
||||
type=make_choice_type_function(["titi", "toto", 42]),
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
args = parser.parse_args([])
|
||||
self.assertEqual(args.foo, "toto")
|
||||
|
||||
args = parser.parse_args(["--foo", "titi"])
|
||||
self.assertEqual(args.foo, "titi")
|
||||
|
||||
args = parser.parse_args(["--foo", "42"])
|
||||
self.assertEqual(args.foo, 42)
|
||||
|
||||
def test_with_list(self):
|
||||
parser = HfArgumentParser(ListExample)
|
||||
@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--required_list", nargs="+", type=int, required=True)
|
||||
expected.add_argument("--required_str", type=str, required=True)
|
||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
self.argparsersEqual(parser, expected)
|
||||
|
||||
def test_with_string_literal_annotation(self):
|
||||
@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase):
|
||||
|
||||
expected = argparse.ArgumentParser()
|
||||
expected.add_argument("--foo", type=int, required=True)
|
||||
expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True)
|
||||
expected.add_argument(
|
||||
"--required_enum",
|
||||
type=make_choice_type_function(["titi", "toto"]),
|
||||
choices=["titi", "toto"],
|
||||
required=True,
|
||||
)
|
||||
expected.add_argument("--opt", type=string_to_bool, default=None)
|
||||
expected.add_argument("--baz", default="toto", type=str, help="help message")
|
||||
expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)
|
||||
|
Loading…
Reference in New Issue
Block a user