From 1e3f17b5abcc6452ec6c9a742375975874978f3e Mon Sep 17 00:00:00 2001 From: Konstantin Dobler Date: Mon, 21 Nov 2022 18:33:37 +0100 Subject: [PATCH] Enhance HfArgumentParser functionality and ease of use (#20323) * Enhance HfArgumentParser * Fix type hints for older python versions * Fix and add tests (+formatting) * Add changes * doc-builder formatting * Remove unused import "Call" --- src/transformers/hf_argparser.py | 138 +++++++++++++++++++++++++++---- tests/utils/test_hf_argparser.py | 89 ++++++++++++++++++-- 2 files changed, 204 insertions(+), 23 deletions(-) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 06a10ff5a05..8dc3e3d6fd3 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -20,11 +20,19 @@ from copy import copy from enum import Enum from inspect import isclass from pathlib import Path -from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints +from typing import Any, Callable, Dict, Iterable, List, NewType, Optional, Tuple, Union, get_type_hints import yaml +try: + # For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/ + from typing import Literal +except ImportError: + # For Python 3.7 + from typing_extensions import Literal + + DataClass = NewType("DataClass", Any) DataClassType = NewType("DataClassType", Any) @@ -43,6 +51,68 @@ def string_to_bool(v): ) +def make_choice_type_function(choices: list) -> Callable[[str], Any]: + """ + Creates a mapping function from each choices string representation to the actual value. Used to support multiple + value types for a single argument. + + Args: + choices (list): List of choices. + + Returns: + Callable[[str], Any]: Mapping function from string representation to actual value for each choice. + """ + str_to_choice = {str(choice): choice for choice in choices} + return lambda arg: str_to_choice.get(arg, arg) + + +def HfArg( + *, + aliases: Union[str, List[str]] = None, + help: str = None, + default: Any = dataclasses.MISSING, + default_factory: Callable[[], Any] = dataclasses.MISSING, + metadata: dict = None, + **kwargs, +) -> dataclasses.Field: + """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`. + + Example comparing the use of `HfArg` and `dataclasses.field`: + ``` + @dataclass + class Args: + regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"}) + hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!") + ``` + + Args: + aliases (Union[str, List[str]], optional): + Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`. + Defaults to None. + help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None. + default (Any, optional): + Default value for the argument. If not default or default_factory is specified, the argument is required. + Defaults to dataclasses.MISSING. + default_factory (Callable[[], Any], optional): + The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide + default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`. + Defaults to dataclasses.MISSING. + metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None. + + Returns: + Field: A `dataclasses.Field` with the desired properties. + """ + if metadata is None: + # Important, don't use as default param in function signature because dict is mutable and shared across function calls + metadata = {} + if aliases is not None: + metadata["aliases"] = aliases + if help is not None: + metadata["help"] = help + + return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs) + + class HfArgumentParser(ArgumentParser): """ This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments. @@ -84,6 +154,10 @@ class HfArgumentParser(ArgumentParser): "`typing.get_type_hints` method by default" ) + aliases = kwargs.pop("aliases", []) + if isinstance(aliases, str): + aliases = [aliases] + origin_type = getattr(field.type, "__origin__", field.type) if origin_type is Union: if str not in field.type.__args__ and ( @@ -108,9 +182,14 @@ class HfArgumentParser(ArgumentParser): # A variable to store kwargs for a boolean field, if needed # so that we can init a `no_*` complement argument (see below) bool_kwargs = {} - if isinstance(field.type, type) and issubclass(field.type, Enum): - kwargs["choices"] = [x.value for x in field.type] - kwargs["type"] = type(kwargs["choices"][0]) + if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)): + if origin_type is Literal: + kwargs["choices"] = field.type.__args__ + else: + kwargs["choices"] = [x.value for x in field.type] + + kwargs["type"] = make_choice_type_function(kwargs["choices"]) + if field.default is not dataclasses.MISSING: kwargs["default"] = field.default else: @@ -146,7 +225,7 @@ class HfArgumentParser(ArgumentParser): kwargs["default"] = field.default_factory() else: kwargs["required"] = True - parser.add_argument(field_name, **kwargs) + parser.add_argument(field_name, *aliases, **kwargs) # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added. # Order is important for arguments with the same destination! @@ -178,7 +257,12 @@ class HfArgumentParser(ArgumentParser): self._parse_dataclass_field(parser, field) def parse_args_into_dataclasses( - self, args=None, return_remaining_strings=False, look_for_args_file=True, args_filename=None + self, + args=None, + return_remaining_strings=False, + look_for_args_file=True, + args_filename=None, + args_file_flag=None, ) -> Tuple[DataClass, ...]: """ Parse command-line args into instances of the specified dataclass types. @@ -196,6 +280,9 @@ class HfArgumentParser(ArgumentParser): process, and will append its potential content to the command line args. args_filename: If not None, will uses this file instead of the ".args" file specified in the previous argument. + args_file_flag: + If not None, will look for a file in the command-line args specified with this flag. The flag can be + specified multiple times and precedence is determined by the order (last one wins). Returns: Tuple consisting of: @@ -205,17 +292,36 @@ class HfArgumentParser(ArgumentParser): after initialization. - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args) """ - if args_filename or (look_for_args_file and len(sys.argv)): - if args_filename: - args_file = Path(args_filename) - else: - args_file = Path(sys.argv[0]).with_suffix(".args") - if args_file.exists(): - fargs = args_file.read_text().split() - args = fargs + args if args is not None else fargs + sys.argv[1:] - # in case of duplicate arguments the first one has precedence - # so we append rather than prepend. + if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)): + args_files = [] + + if args_filename: + args_files.append(Path(args_filename)) + elif look_for_args_file and len(sys.argv): + args_files.append(Path(sys.argv[0]).with_suffix(".args")) + + # args files specified via command line flag should overwrite default args files so we add them last + if args_file_flag: + # Create special parser just to extract the args_file_flag values + args_file_parser = ArgumentParser() + args_file_parser.add_argument(args_file_flag, type=str, action="append") + + # Use only remaining args for further parsing (remove the args_file_flag) + cfg, args = args_file_parser.parse_known_args(args=args) + cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None) + + if cmd_args_file_paths: + args_files.extend([Path(p) for p in cmd_args_file_paths]) + + file_args = [] + for args_file in args_files: + if args_file.exists(): + file_args += args_file.read_text().split() + + # in case of duplicate arguments the last one has precedence + # args specified via the command line should overwrite args from files, so we add them last + args = file_args + args if args is not None else file_args + sys.argv[1:] namespace, remaining_args = self.parse_known_args(args=args) outputs = [] for dtype in self.dataclass_types: diff --git a/tests/utils/test_hf_argparser.py b/tests/utils/test_hf_argparser.py index 9dfff75948c..da824f47438 100644 --- a/tests/utils/test_hf_argparser.py +++ b/tests/utils/test_hf_argparser.py @@ -25,7 +25,15 @@ from typing import List, Optional import yaml from transformers import HfArgumentParser, TrainingArguments -from transformers.hf_argparser import string_to_bool +from transformers.hf_argparser import make_choice_type_function, string_to_bool + + +try: + # For Python versions <3.8, Literal is not in typing: https://peps.python.org/pep-0586/ + from typing import Literal +except ImportError: + # For Python 3.7 + from typing_extensions import Literal def list_field(default=None, metadata=None): @@ -58,6 +66,12 @@ class BasicEnum(Enum): toto = "toto" +class MixedTypeEnum(Enum): + titi = "titi" + toto = "toto" + fourtytwo = 42 + + @dataclass class EnumExample: foo: BasicEnum = "toto" @@ -66,6 +80,14 @@ class EnumExample: self.foo = BasicEnum(self.foo) +@dataclass +class MixedTypeEnumExample: + foo: MixedTypeEnum = "toto" + + def __post_init__(self): + self.foo = MixedTypeEnum(self.foo) + + @dataclass class OptionalExample: foo: Optional[int] = None @@ -111,6 +133,14 @@ class HfArgumentParserTest(unittest.TestCase): for x, y in zip(a._actions, b._actions): xx = {k: v for k, v in vars(x).items() if k != "container"} yy = {k: v for k, v in vars(y).items() if k != "container"} + + # Choices with mixed type have custom function as "type" + # So we need to compare results directly for equality + if xx.get("choices", None) and yy.get("choices", None): + for expected_choice in yy["choices"] + xx["choices"]: + self.assertEqual(xx["type"](expected_choice), yy["type"](expected_choice)) + del xx["type"], yy["type"] + self.assertEqual(xx, yy) def test_basic(self): @@ -163,21 +193,56 @@ class HfArgumentParserTest(unittest.TestCase): self.assertEqual(args, Namespace(foo=False, baz=False, opt=False)) def test_with_enum(self): - parser = HfArgumentParser(EnumExample) + parser = HfArgumentParser(MixedTypeEnumExample) expected = argparse.ArgumentParser() - expected.add_argument("--foo", default="toto", choices=["titi", "toto"], type=str) + expected.add_argument( + "--foo", + default="toto", + choices=["titi", "toto", 42], + type=make_choice_type_function(["titi", "toto", 42]), + ) self.argparsersEqual(parser, expected) args = parser.parse_args([]) self.assertEqual(args.foo, "toto") enum_ex = parser.parse_args_into_dataclasses([])[0] - self.assertEqual(enum_ex.foo, BasicEnum.toto) + self.assertEqual(enum_ex.foo, MixedTypeEnum.toto) args = parser.parse_args(["--foo", "titi"]) self.assertEqual(args.foo, "titi") enum_ex = parser.parse_args_into_dataclasses(["--foo", "titi"])[0] - self.assertEqual(enum_ex.foo, BasicEnum.titi) + self.assertEqual(enum_ex.foo, MixedTypeEnum.titi) + + args = parser.parse_args(["--foo", "42"]) + self.assertEqual(args.foo, 42) + enum_ex = parser.parse_args_into_dataclasses(["--foo", "42"])[0] + self.assertEqual(enum_ex.foo, MixedTypeEnum.fourtytwo) + + def test_with_literal(self): + @dataclass + class LiteralExample: + foo: Literal["titi", "toto", 42] = "toto" + + parser = HfArgumentParser(LiteralExample) + + expected = argparse.ArgumentParser() + expected.add_argument( + "--foo", + default="toto", + choices=("titi", "toto", 42), + type=make_choice_type_function(["titi", "toto", 42]), + ) + self.argparsersEqual(parser, expected) + + args = parser.parse_args([]) + self.assertEqual(args.foo, "toto") + + args = parser.parse_args(["--foo", "titi"]) + self.assertEqual(args.foo, "titi") + + args = parser.parse_args(["--foo", "42"]) + self.assertEqual(args.foo, 42) def test_with_list(self): parser = HfArgumentParser(ListExample) @@ -222,7 +287,12 @@ class HfArgumentParserTest(unittest.TestCase): expected = argparse.ArgumentParser() expected.add_argument("--required_list", nargs="+", type=int, required=True) expected.add_argument("--required_str", type=str, required=True) - expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True) + expected.add_argument( + "--required_enum", + type=make_choice_type_function(["titi", "toto"]), + choices=["titi", "toto"], + required=True, + ) self.argparsersEqual(parser, expected) def test_with_string_literal_annotation(self): @@ -230,7 +300,12 @@ class HfArgumentParserTest(unittest.TestCase): expected = argparse.ArgumentParser() expected.add_argument("--foo", type=int, required=True) - expected.add_argument("--required_enum", type=str, choices=["titi", "toto"], required=True) + expected.add_argument( + "--required_enum", + type=make_choice_type_function(["titi", "toto"]), + choices=["titi", "toto"], + required=True, + ) expected.add_argument("--opt", type=string_to_bool, default=None) expected.add_argument("--baz", default="toto", type=str, help="help message") expected.add_argument("--foo_str", nargs="+", default=["Hallo", "Bonjour", "Hello"], type=str)