mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[QoL] Allow dtype str for torch_dtype arg of from_pretrained (#31590)
* Allow dtype str for torch_dtype in from_pretrained
* Update docstring
* Add tests for str torch_dtype
This commit is contained in:
parent
11138ca013
commit
3a028101e9
@ -2958,6 +2958,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
|
||||
the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
|
||||
|
||||
3. A string naming a valid `torch.dtype`, e.g. "float32" loads the model in `torch.float32` and "float16" loads it in `torch.float16`.
|
||||
|
||||
<Tip>
|
||||
|
||||
For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
|
||||
@ -3661,9 +3663,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
"Since the `torch_dtype` attribute can't be found in model's config object, "
|
||||
"will use torch_dtype={torch_dtype} as derived from model's weights"
|
||||
)
|
||||
elif hasattr(torch, torch_dtype):
|
||||
torch_dtype = getattr(torch, torch_dtype)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}'
|
||||
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
|
||||
)
|
||||
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
|
||||
|
||||
|
@ -445,6 +445,18 @@ class ModelUtilsTest(TestCasePlus):
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||
|
||||
def test_model_from_config_torch_dtype_str(self):
    """Check that `from_pretrained` accepts `torch_dtype` given as a string such as "float32"."""
    # Each dtype name is expected to load the model in the `torch` dtype of the same name.
    for dtype_name, expected_dtype in (("float32", torch.float32), ("float16", torch.float16)):
        loaded = AutoModel.from_pretrained(TINY_T5, torch_dtype=dtype_name)
        self.assertEqual(loaded.dtype, expected_dtype)

    # torch.set_default_dtype() supports only float dtypes, so a non-float name is expected to fail
    with self.assertRaises(ValueError):
        AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
|
||||
|
||||
def test_model_from_pretrained_torch_dtype(self):
|
||||
# test that the model can be instantiated with dtype of either
|
||||
# 1. explicit from_pretrained's torch_dtype argument
|
||||
|
Loading…
Reference in New Issue
Block a user