mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 21:00:08 +06:00
Add support for torch_dtype
in the run_mlm example (#29776)
feat: add support for torch_dtype Co-authored-by: Jacky Lee <jackylee328@gmail.com>
This commit is contained in:
parent
10d232e88e
commit
ef6e371dba
@ -32,6 +32,7 @@ from typing import Optional
|
||||
|
||||
import datasets
|
||||
import evaluate
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
import transformers
|
||||
@ -133,6 +134,16 @@ class ModelArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
),
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
low_cpu_mem_usage: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -425,6 +436,11 @@ def main():
|
||||
)
|
||||
|
||||
if model_args.model_name_or_path:
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model = AutoModelForMaskedLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
||||
@ -433,6 +449,7 @@ def main():
|
||||
revision=model_args.model_revision,
|
||||
token=model_args.token,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
torch_dtype=torch_dtype,
|
||||
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user