diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index f6861212632..55836022546 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -66,9 +66,15 @@ class HfArgumentParser(ArgumentParser): typestring = str(field.type) for prim_type in (int, float, str): for collection in (List,): - if typestring == f"typing.Union[{collection[prim_type]}, NoneType]": + if ( + typestring == f"typing.Union[{collection[prim_type]}, NoneType]" + or typestring == f"typing.Optional[{collection[prim_type]}]" + ): field.type = collection[prim_type] - if typestring == f"typing.Union[{prim_type.__name__}, NoneType]": + if ( + typestring == f"typing.Union[{prim_type.__name__}, NoneType]" + or typestring == f"typing.Optional[{prim_type.__name__}]" + ): field.type = prim_type if isinstance(field.type, type) and issubclass(field.type, Enum):