[QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590)

* Allow dtype str for torch_dtype in from_pretrained

* Update docstring

* Add tests for str torch_dtype
This commit is contained in:
Billy Cao 2024-06-27 18:41:49 +08:00 committed by GitHub
parent 11138ca013
commit 3a028101e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 1 deletions

View File

@ -2958,6 +2958,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
<Tip>
For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
@ -3661,9 +3663,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"Since the `torch_dtype` attribute can't be found in model's config object, "
"will use torch_dtype={torch_dtype} as derived from model's weights"
)
elif hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
else:
raise ValueError(
f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}'
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

View File

@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_config(config, torch_dtype=torch.int64)
def test_model_from_config_torch_dtype_str(self):
# test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
self.assertEqual(model.dtype, torch.float32)
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
self.assertEqual(model.dtype, torch.float16)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
def test_model_from_pretrained_torch_dtype(self):
# test that the model can be instantiated with dtype of either
# 1. explicit from_pretrained's torch_dtype argument