mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-01 18:51:14 +06:00
[AutoModel] fix torch_dtype=auto
in from_pretrained
(#23379)
* [automodel] fix torch_dtype=auto in from_pretrained * add test * fix logic * Update src/transformers/models/auto/auto_factory.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
8a58809312
commit
bbbc5c15d4
@ -435,19 +435,24 @@ class _BaseAutoModelClass:
|
|||||||
]
|
]
|
||||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||||
if not isinstance(config, PretrainedConfig):
|
if not isinstance(config, PretrainedConfig):
|
||||||
kwargs_copy = copy.deepcopy(kwargs)
|
kwargs_orig = copy.deepcopy(kwargs)
|
||||||
# ensure not to pollute the config object with torch_dtype="auto" - since it's
|
# ensure not to pollute the config object with torch_dtype="auto" - since it's
|
||||||
# meaningless in the context of the config object - torch.dtype values are acceptable
|
# meaningless in the context of the config object - torch.dtype values are acceptable
|
||||||
if kwargs_copy.get("torch_dtype", None) == "auto":
|
if kwargs.get("torch_dtype", None) == "auto":
|
||||||
_ = kwargs_copy.pop("torch_dtype")
|
_ = kwargs.pop("torch_dtype")
|
||||||
|
|
||||||
config, kwargs = AutoConfig.from_pretrained(
|
config, kwargs = AutoConfig.from_pretrained(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
return_unused_kwargs=True,
|
return_unused_kwargs=True,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
**hub_kwargs,
|
**hub_kwargs,
|
||||||
**kwargs_copy,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# if torch_dtype=auto was passed here, ensure to pass it on
|
||||||
|
if kwargs_orig.get("torch_dtype", None) == "auto":
|
||||||
|
kwargs["torch_dtype"] = "auto"
|
||||||
|
|
||||||
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
|
||||||
if not trust_remote_code:
|
if not trust_remote_code:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -2920,6 +2920,10 @@ class ModelUtilsTest(TestCasePlus):
|
|||||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
|
||||||
|
model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
# test fp16 save_pretrained, loaded with the explicit fp16
|
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||||
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
self.assertEqual(model.dtype, torch.float16)
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
Loading…
Reference in New Issue
Block a user