mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
add _keep_in_fp32_modules_strict (#39058)
* add _keep_in_fp32_modules_strict * complete test
This commit is contained in:
parent
d973e62fdd
commit
02ecdcfc0f
@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
_auto_class = None
|
||||
_no_split_modules = None
|
||||
_skip_keys_device_placement = None
|
||||
|
||||
_keep_in_fp32_modules = None
|
||||
# the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
|
||||
# to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
|
||||
_keep_in_fp32_modules_strict = None
|
||||
|
||||
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
|
||||
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
|
||||
@ -2049,6 +2053,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
# `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
|
||||
# when a different component (e.g. language_model) is used.
|
||||
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
|
||||
self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
|
||||
|
||||
self._no_split_modules = self._no_split_modules or []
|
||||
|
||||
@ -2061,7 +2066,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
self._backward_compatibility_gradient_checkpointing()
|
||||
|
||||
# Make sure the modules correctly exist if the flag is active
|
||||
if self._keep_in_fp32_modules is not None:
|
||||
if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
|
||||
all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
|
||||
unique_module_names = set()
|
||||
# Get all unique module names in the module graph, without the prefixes
|
||||
@ -2070,12 +2075,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
[name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
|
||||
)
|
||||
# Check that every module in the keep_in_fp32 list is part of the module graph
|
||||
for module in self._keep_in_fp32_modules:
|
||||
if module not in unique_module_names:
|
||||
raise ValueError(
|
||||
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
|
||||
f" {self.__class__.__name__}"
|
||||
)
|
||||
if self._keep_in_fp32_modules is not None:
|
||||
for module in self._keep_in_fp32_modules:
|
||||
if module not in unique_module_names:
|
||||
raise ValueError(
|
||||
f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
|
||||
f" {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
if self._keep_in_fp32_modules_strict is not None:
|
||||
for module in self._keep_in_fp32_modules_strict:
|
||||
if module not in unique_module_names:
|
||||
raise ValueError(
|
||||
f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
|
||||
f" {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
|
||||
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
|
||||
@ -4757,20 +4771,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
|
||||
config = model.config
|
||||
|
||||
# Find fp32 modules if needed
|
||||
keep_in_fp32_regex = None
|
||||
keep_in_fp32_modules = []
|
||||
# The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
|
||||
# in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
|
||||
# step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
|
||||
# Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
|
||||
if model._keep_in_fp32_modules is not None and (
|
||||
torch_dtype == torch.float16
|
||||
or torch_dtype == torch.bfloat16
|
||||
or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
|
||||
torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
|
||||
):
|
||||
keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
|
||||
|
||||
if model._keep_in_fp32_modules_strict is not None and (
|
||||
torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
|
||||
):
|
||||
keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
|
||||
|
||||
keep_in_fp32_regex = None
|
||||
if keep_in_fp32_modules:
|
||||
# We need to match exact layers, so we add either `.` on each side, or start/end of string
|
||||
keep_in_fp32_regex = re.compile(
|
||||
"|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
|
||||
)
|
||||
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
|
@ -1103,7 +1103,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
|
||||
_tied_weights_keys = ["lm_head.weight"]
|
||||
_tp_plan = {"lm_head": "colwise_rep"}
|
||||
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
||||
_keep_in_fp32_modules = ["codec_model"]
|
||||
_keep_in_fp32_modules_strict = ["codec_model"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -252,7 +252,7 @@ class KyutaiSpeechToTextModel(MoshiModel):
|
||||
|
||||
|
||||
class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
|
||||
_keep_in_fp32_modules = ["codec_model"]
|
||||
_keep_in_fp32_modules_strict = ["codec_model"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -30,6 +30,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
cleanup,
|
||||
require_accelerate,
|
||||
require_torch,
|
||||
require_torch_accelerator,
|
||||
require_torch_sdpa,
|
||||
@ -615,6 +616,81 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
||||
self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_accelerate
|
||||
@slow
|
||||
class KyutaiSpeechToTextBf16Test(unittest.TestCase):
|
||||
def test_bf16_fp32_conversion(self):
|
||||
r"""
|
||||
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
|
||||
"""
|
||||
model_checkpoint = "kyutai/stt-2.6b-en-trfs"
|
||||
orig_import = __import__
|
||||
accelerate_mock = unittest.mock.Mock()
|
||||
|
||||
# mock import of accelerate
|
||||
def import_accelerate_mock(name, *args, **kwargs):
|
||||
if name == "accelerate":
|
||||
if accelerate_available:
|
||||
return accelerate_mock
|
||||
else:
|
||||
raise ImportError
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
# Load without using `accelerate`
|
||||
with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
|
||||
accelerate_available = False
|
||||
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.float16
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
# Load without in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.bfloat16
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load using `accelerate` in bf16
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.bfloat16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
|
||||
|
||||
# Load without using `accelerate`
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
# Load using `accelerate`
|
||||
model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
|
||||
model_checkpoint, torch_dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
self.assertTrue(model.codec_model.dtype == torch.float32)
|
||||
self.assertTrue(model.model.dtype == torch.float16)
|
||||
self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
|
||||
|
||||
|
||||
class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
|
||||
_dataset = None
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user