Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-05 22:00:09 +06:00)
Add support for torch_dtype in the run_mlm example (#29776)
feat: add support for torch_dtype

Co-authored-by: Jacky Lee <jackylee328@gmail.com>
This commit is contained in:
parent
10d232e88e
commit
ef6e371dba
@@ -32,6 +32,7 @@ from typing import Optional
 import datasets
 import evaluate
+import torch
 from datasets import load_dataset
 
 import transformers
@@ -133,6 +134,16 @@ class ModelArguments:
             )
         },
     )
+    torch_dtype: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
+                "dtype will be automatically derived from the model's weights."
+            ),
+            "choices": ["auto", "bfloat16", "float16", "float32"],
+        },
+    )
     low_cpu_mem_usage: bool = field(
         default=False,
         metadata={
@@ -425,6 +436,11 @@ def main():
     )
 
     if model_args.model_name_or_path:
+        torch_dtype = (
+            model_args.torch_dtype
+            if model_args.torch_dtype in ["auto", None]
+            else getattr(torch, model_args.torch_dtype)
+        )
         model = AutoModelForMaskedLM.from_pretrained(
             model_args.model_name_or_path,
             from_tf=bool(".ckpt" in model_args.model_name_or_path),
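Note on the hunk above: the conditional expression only maps an explicit dtype name to the corresponding torch.dtype attribute, while "auto" and None are passed through unchanged so that from_pretrained can decide. A minimal standalone sketch of that behavior (the helper name resolve_torch_dtype is hypothetical, not part of run_mlm.py):

import torch

# Hypothetical helper mirroring the resolution logic in the hunk above.
def resolve_torch_dtype(name):
    # "auto" and None are handed straight to from_pretrained;
    # any other allowed choice is looked up on the torch module.
    return name if name in ["auto", None] else getattr(torch, name)

assert resolve_torch_dtype(None) is None
assert resolve_torch_dtype("auto") == "auto"
assert resolve_torch_dtype("bfloat16") is torch.bfloat16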
@@ -433,6 +449,7 @@ def main():
             revision=model_args.model_revision,
             token=model_args.token,
             trust_remote_code=model_args.trust_remote_code,
+            torch_dtype=torch_dtype,
             low_cpu_mem_usage=model_args.low_cpu_mem_usage,
         )
     else:
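Taken together, the change lets the masked-language-modeling example load its model under a user-chosen dtype. The sketch below shows the end-to-end flow under stated assumptions: a trimmed-down dataclass (DtypeArguments) stands in for the full ModelArguments, and the checkpoint name is a placeholder rather than anything the PR prescribes.

from dataclasses import dataclass, field
from typing import Optional

import torch
from transformers import AutoModelForMaskedLM, HfArgumentParser


@dataclass
class DtypeArguments:
    # Trimmed-down stand-in for ModelArguments; only the new field is kept.
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={"choices": ["auto", "bfloat16", "float16", "float32"]},
    )


(args,) = HfArgumentParser(DtypeArguments).parse_args_into_dataclasses(
    args=["--torch_dtype", "float16"]
)

# Same resolution as in main(): keep "auto"/None, otherwise look up the dtype.
torch_dtype = (
    args.torch_dtype
    if args.torch_dtype in ["auto", None]
    else getattr(torch, args.torch_dtype)
)

model = AutoModelForMaskedLM.from_pretrained(
    "FacebookAI/roberta-base",  # placeholder checkpoint, not specified by the PR
    torch_dtype=torch_dtype,
)
print(model.dtype)  # torch.float16

Passing --torch_dtype float16 (or bfloat16) roughly halves the memory needed to hold the weights compared with float32, which is the main practical benefit of exposing the flag in the example script.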