mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00

Add `keep_in_fp32_modules` support (#20683)
* add `keep_in_fp32_modules` support
* pass it as class attribute
* few modifs
  - make tests `slow`
  - fix logic
* better logic
* fix failing test
* `bfloat16` support
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* fix
* simplify tests
* simplify tests
* fix test
* modify message
* more checks
* fix failing tests
* add more conditions
  - add `is_accelerate_available`
  - fixes pipeline tests that failed
* add suggestions
* Update src/transformers/modeling_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* fix failing `bnb` test
* add last safety checker

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
  parent d4bf9ee1ff
  commit 1af4bee896
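Before the diff itself, a minimal usage sketch (Python) of what the new attribute does. It is not part of the commit; it simply mirrors the behavior asserted by the `T5ModelFp16Tests` test added below, and assumes `torch`, `transformers`, and `accelerate` are installed, since the new logic only activates when `is_accelerate_available()` is true.

import torch
from transformers import T5ForConditionalGeneration

# T5 now declares `_keep_in_fp32_modules = ["wo"]` (see the T5PreTrainedModel hunk below),
# so loading in float16 keeps the feed-forward output projection `wo` in float32.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)

dense = model.decoder.block[0].layer[2].DenseReluDense
assert dense.wo.weight.dtype == torch.float32  # kept in fp32 for numerical stability
assert dense.wi.weight.dtype == torch.float16  # everything else follows `torch_dtype`

# With bfloat16 the override does not apply: the upcast is only performed for float16.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
assert model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16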
@@ -562,6 +562,7 @@ def _load_state_dict_into_meta_model(
     dtype=None,
     load_in_8bit=False,
     is_safetensors=False,
+    keep_in_fp32_modules=None,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -612,7 +613,14 @@ def _load_state_dict_into_meta_model(
         # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
         # in int/uint/bool and not cast them.
         if dtype is not None and torch.is_floating_point(param):
-            param = param.to(dtype)
+            if (
+                keep_in_fp32_modules is not None
+                and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
+                and dtype == torch.float16
+            ):
+                param = param.to(torch.float32)
+            else:
+                param = param.to(dtype)
 
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
         if dtype is None:
@@ -974,6 +982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     main_input_name = "input_ids"
     _auto_class = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2071,6 +2080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Load model
         loading_info = None
 
+        # Keep in fp32 modules
+        keep_in_fp32_modules = None
+        use_keep_in_fp32_modules = False
+
         if pretrained_model_name_or_path is not None:
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
             is_local = os.path.isdir(pretrained_model_name_or_path)
@@ -2269,6 +2282,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
             # we also may have config.torch_dtype available, but we won't rely on it till v5
             dtype_orig = None
+
             if torch_dtype is not None:
                 if isinstance(torch_dtype, str):
                     if torch_dtype == "auto":
@@ -2286,11 +2300,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         )
                 dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
+            # Check if `_keep_in_fp32_modules` is not None
+            use_keep_in_fp32_modules = (
+                (cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16
+            )
+            if (
+                (cls._keep_in_fp32_modules is not None)
+                and not is_accelerate_available()
+                and torch_dtype == torch.float16
+            ):
+                logger.warning(
+                    "For stability purposes, it is recommended to have accelerate installed when using this model in"
+                    " torch.float16, please install it with `pip install accelerate`"
+                )
+
             if is_sharded:
                 loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
             else:
                 loaded_state_dict_keys = [k for k in state_dict.keys()]
-            if low_cpu_mem_usage:
+            if low_cpu_mem_usage or use_keep_in_fp32_modules:
                 state_dict = None
 
         config.name_or_path = pretrained_model_name_or_path
@@ -2309,6 +2337,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         with ContextManagers(init_contexts):
             model = cls(config, *model_args, **model_kwargs)
 
+        # Check first if we are `from_pt`
+        if use_keep_in_fp32_modules:
+            low_cpu_mem_usage = True
+            keep_in_fp32_modules = model._keep_in_fp32_modules
+        else:
+            keep_in_fp32_modules = []
+
         if load_in_8bit:
             from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
 
@@ -2319,6 +2354,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 modules_to_not_convert = get_keys_to_not_convert(model)
             else:
                 modules_to_not_convert = load_in_8bit_skip_modules
+
+            if not isinstance(modules_to_not_convert, list):
+                modules_to_not_convert = [modules_to_not_convert]
+
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
             model = replace_8bit_linear(
                 model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
             )
@@ -2425,6 +2466,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 offload_state_dict=offload_state_dict,
                 dtype=torch_dtype,
                 load_in_8bit=load_in_8bit,
+                keep_in_fp32_modules=keep_in_fp32_modules,
             )
 
         model.is_loaded_in_8bit = load_in_8bit
@@ -2468,6 +2510,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         offload_state_dict=None,
         dtype=None,
         load_in_8bit=False,
+        keep_in_fp32_modules=None,
     ):
         is_safetensors = False
         if load_in_8bit:
@@ -2544,11 +2587,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 if key.startswith(prefix):
                     key = ".".join(key.split(".")[1:])
                 param = model_state_dict[key]
+
+                # upcast in fp32 if any
+                target_dtype = dtype
+                if (
+                    keep_in_fp32_modules is not None
+                    and dtype == torch.float16
+                    and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules)
+                ):
+                    target_dtype = torch.float32
+
                 if param.device == torch.device("meta"):
                     if not load_in_8bit:
-                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
+                        set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
                     else:
-                        set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
+                        set_module_8bit_tensor_to_device(
+                            model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
+                        )
 
         # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
         if _fast_init:
@@ -2558,6 +2613,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             for module in uninitialized_modules:
                 model._init_weights(module)
 
+        # Set some modules to fp32 if any
+        if keep_in_fp32_modules is not None:
+            for name, param in model.named_parameters():
+                if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                    param = param.to(torch.float32)
+
         # Make sure we are able to load base models as well as derived models (with heads)
         start_prefix = ""
         model_to_load = model
@@ -2693,6 +2754,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         dtype=dtype,
                         load_in_8bit=load_in_8bit,
                         is_safetensors=is_safetensors,
+                        keep_in_fp32_modules=keep_in_fp32_modules,
                     )
                     error_msgs += new_error_msgs
                 else:
@@ -757,6 +757,7 @@ class T5PreTrainedModel(PreTrainedModel):
     is_parallelizable = True
     supports_gradient_checkpointing = True
     _no_split_modules = ["T5Block"]
+    _keep_in_fp32_modules = ["wo"]
 
     @property
     def dummy_inputs(self):
@@ -150,7 +150,7 @@ def get_keys_to_not_convert(model):
 
     # Ignore this for base models (BertModel, GPT2Model, etc.)
     if (not has_tied_params) and is_base_model:
-        return ""
+        return []
 
     # otherwise they have an attached head
     list_modules = list(model.named_parameters())
@@ -155,6 +155,13 @@ class MixedInt8Test(BaseMixedInt8Test):
         # Check this does not throw an error
         _ = self.model_fp16.float()
 
+    def test_fp32_int8_conversion(self):
+        r"""
+        Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+        """
+        model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+
 
 class MixedInt8ModelClassesTest(BaseMixedInt8Test):
     def setUp(self):
@@ -19,7 +19,14 @@ import tempfile
 import unittest
 
 from transformers import T5Config, is_torch_available
-from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_sentencepiece,
+    require_tokenizers,
+    require_torch,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -820,6 +827,50 @@ def use_task_specific_params(model, task):
         model.config.update(model.config.task_specific_params[task])
 
 
+@require_torch
+@require_accelerate
+@require_tokenizers
+@slow
+class T5ModelFp16Tests(unittest.TestCase):
+    def test_fp16_fp32_conversion(self):
+        r"""
+        A test to check whether the argument `keep_in_fp32_modules` correctly does its job
+        """
+        # Load without using `accelerate`
+        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
+
+        # Load without in bf16
+        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = T5ForConditionalGeneration.from_pretrained(
+            "t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
+        )
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
+
+        # Load without using `accelerate`
+        model = T5ForConditionalGeneration.from_pretrained(
+            "t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
+        )
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
+
+        # Load using `accelerate`
+        model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto")
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+        self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
+
+
 @require_torch
 @require_sentencepiece
 @require_tokenizers