Add keep_in_fp32_modules support (#20683)

* add `keep_in_fp32_modules` support

* pass it as class attribute

* few modifs

- make tests `slow`
- fix logic

* better logic

* fix failing test

* `bfloat16` support

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* simplify tests

* simplify tests

* fix test

* modify message

* more checks

* fix failing tests

* add more conditions

- add `is_accelerate_available`
- fix pipeline tests that failed

* add suggestions

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix failing `bnb` test

* add last safety checker

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Authored by Younes Belkada on 2022-12-13 11:59:57 +01:00; committed by GitHub
parent d4bf9ee1ff
commit 1af4bee896
5 changed files with 127 additions and 6 deletions
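
For context, a minimal usage sketch of what this feature does end to end (an illustration based on the T5 changes and tests below, not part of the diff; per the gate added in modeling_utils.py it requires `accelerate` to be installed):

import torch
from transformers import T5ForConditionalGeneration

# T5 declares `_keep_in_fp32_modules = ["wo"]`, so loading the checkpoint in
# float16 keeps the feed-forward output projection (`DenseReluDense.wo`) in
# float32 for numerical stability, while the remaining weights are float16.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)

print(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype)  # torch.float32
print(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype)  # torch.float16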

src/transformers/modeling_utils.py

@@ -562,6 +562,7 @@ def _load_state_dict_into_meta_model(
dtype=None,
load_in_8bit=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -612,6 +613,13 @@ def _load_state_dict_into_meta_model(
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
if (
keep_in_fp32_modules is not None
and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
and dtype == torch.float16
):
param = param.to(torch.float32)
else:
param = param.to(dtype)
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
@@ -974,6 +982,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
main_input_name = "input_ids"
_auto_class = None
_no_split_modules = None
_keep_in_fp32_modules = None
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2071,6 +2080,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Load model
loading_info = None
# Keep in fp32 modules
keep_in_fp32_modules = None
use_keep_in_fp32_modules = False
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
@@ -2269,6 +2282,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None
if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
@@ -2286,11 +2300,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (
(cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16
)
if (
(cls._keep_in_fp32_modules is not None)
and not is_accelerate_available()
and torch_dtype == torch.float16
):
logger.warning(
"For stability purposes, it is recommended to have accelerate installed when using this model in"
" torch.float16, please install it with `pip install accelerate`"
)
if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
if low_cpu_mem_usage or use_keep_in_fp32_modules:
state_dict = None
config.name_or_path = pretrained_model_name_or_path
@@ -2309,6 +2337,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)
# Check first if we are `from_pt`
if use_keep_in_fp32_modules:
low_cpu_mem_usage = True
keep_in_fp32_modules = model._keep_in_fp32_modules
else:
keep_in_fp32_modules = []
if load_in_8bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
@@ -2319,6 +2354,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = load_in_8bit_skip_modules
if not isinstance(modules_to_not_convert, list):
modules_to_not_convert = [modules_to_not_convert]
modules_to_not_convert.extend(keep_in_fp32_modules)
model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
)
@@ -2425,6 +2466,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
keep_in_fp32_modules=keep_in_fp32_modules,
)
model.is_loaded_in_8bit = load_in_8bit
@@ -2468,6 +2510,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
keep_in_fp32_modules=None,
):
is_safetensors = False
if load_in_8bit:
@@ -2544,11 +2587,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if key.startswith(prefix):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]
# upcast to fp32 if needed
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules)
):
target_dtype = torch.float32
if param.device == torch.device("meta"):
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_8bit_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)
# retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@@ -2558,6 +2613,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for module in uninitialized_modules:
model._init_weights(module)
# Set some modules to fp32 if any
if keep_in_fp32_modules is not None:
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
param = param.to(torch.float32)
# Make sure we are able to load base models as well as derived models (with heads)
start_prefix = ""
model_to_load = model
@@ -2693,6 +2754,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
dtype=dtype,
load_in_8bit=load_in_8bit,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
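
As a reading aid (not part of the diff), the gating added above condenses to the following standalone sketch; `_should_upcast` is a made-up helper name for illustration:

import torch

def _should_upcast(param_name, keep_in_fp32_modules, dtype):
    # A parameter is kept in float32 only when the model class opted in via
    # `_keep_in_fp32_modules`, the requested dtype is float16, and the parameter
    # name contains one of the listed substrings.
    return (
        keep_in_fp32_modules is not None
        and dtype == torch.float16
        and any(m in param_name for m in keep_in_fp32_modules)
    )

assert _should_upcast("decoder.block.0.layer.2.DenseReluDense.wo.weight", ["wo"], torch.float16)
assert not _should_upcast("decoder.block.0.layer.2.DenseReluDense.wi.weight", ["wo"], torch.float16)
assert not _should_upcast("decoder.block.0.layer.2.DenseReluDense.wo.weight", ["wo"], torch.bfloat16)

Note that the diff additionally requires `is_accelerate_available()` before enabling the feature, since the upcast path goes through the `low_cpu_mem_usage` loading code.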

src/transformers/models/t5/modeling_t5.py

@@ -757,6 +757,7 @@ class T5PreTrainedModel(PreTrainedModel):
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"]
@property
def dummy_inputs(self):
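
Opting a model class into this behaviour is just a class attribute on the `PreTrainedModel` subclass, as the one-line T5 change above shows. A minimal sketch (the class name is hypothetical):

from transformers import PreTrainedModel

class MyModel(PreTrainedModel):  # hypothetical model class, for illustration only
    # Any parameter whose qualified name contains one of these substrings is
    # kept in float32 when the model is loaded with torch_dtype=torch.float16.
    _keep_in_fp32_modules = ["wo"]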

src/transformers/utils/bitsandbytes.py

@@ -150,7 +150,7 @@ def get_keys_to_not_convert(model):
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return ""
return []
# otherwise they have an attached head
list_modules = list(model.named_parameters())
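
One way to read the `return ""` → `return []` change (an aside, not part of the diff): the loading code above now calls `modules_to_not_convert.extend(keep_in_fp32_modules)`, which requires a list, and wrapping the old empty-string return would leave a spurious `""` entry behind:

# Old behaviour: the helper returned "" for base models without tied params.
modules_to_not_convert = ""
if not isinstance(modules_to_not_convert, list):
    modules_to_not_convert = [modules_to_not_convert]  # -> [""] (spurious empty entry)
modules_to_not_convert.extend(["wo"])                  # -> ["", "wo"]

# New behaviour: returning [] keeps the type consistent and avoids the empty entry.
modules_to_not_convert = []
modules_to_not_convert.extend(["wo"])                  # -> ["wo"]
print(modules_to_not_convert)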

tests/mixed_int8/test_mixed_int8.py

@@ -155,6 +155,13 @@ class MixedInt8Test(BaseMixedInt8Test):
# Check this does not throw an error
_ = self.model_fp16.float()
def test_fp32_int8_conversion(self):
r"""
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
"""
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):

tests/models/t5/test_modeling_t5.py

@@ -19,7 +19,14 @@ import tempfile
import unittest
from transformers import T5Config, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
torch_device,
)
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
@@ -820,6 +827,50 @@ def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])
@require_torch
@require_accelerate
@require_tokenizers
@slow
class T5ModelFp16Tests(unittest.TestCase):
def test_fp16_fp32_conversion(self):
r"""
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
"""
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
# Load in bf16 without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load in bf16 with `low_cpu_mem_usage=True`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)
# Load in fp16 with `low_cpu_mem_usage=True`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
# Load using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)
@require_torch
@require_sentencepiece
@require_tokenizers