add _keep_in_fp32_modules_strict (#39058)

* add _keep_in_fp32_modules_strict

* complete test
Author: eustlb, 2025-06-26 15:55:28 +02:00 (committed by GitHub)
Parent: d973e62fdd
Commit: 02ecdcfc0f
4 changed files with 111 additions and 17 deletions

src/transformers/modeling_utils.py

@@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
     _auto_class = None
     _no_split_modules = None
     _skip_keys_device_placement = None
     _keep_in_fp32_modules = None
+    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
+    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
+    _keep_in_fp32_modules_strict = None
+
     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
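For illustration, this is how a model class would opt into the new flag. Only the two class attributes come from this commit; the class and module names below are hypothetical:

```python
from transformers import PreTrainedModel


class MyAudioModel(PreTrainedModel):
    # Upcast to float32 only when loading in float16 (or when a quantizer requests it);
    # a bfloat16 load leaves these modules in bfloat16.
    _keep_in_fp32_modules = ["feature_projection"]

    # Upcast to float32 when loading in float16 *or* bfloat16.
    _keep_in_fp32_modules_strict = ["codec_model"]
```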
@@ -2049,6 +2053,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
         # when a different component (e.g. language_model) is used.
         self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
 
         self._no_split_modules = self._no_split_modules or []
@@ -2061,7 +2066,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         self._backward_compatibility_gradient_checkpointing()
 
         # Make sure the modules correctly exist if the flag is active
-        if self._keep_in_fp32_modules is not None:
+        if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
             all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
             unique_module_names = set()
             # Get all unique module names in the module graph, without the prefixes
@@ -2070,12 +2075,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
                     [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
                 )
             # Check that every module in the keep_in_fp32 list is part of the module graph
-            for module in self._keep_in_fp32_modules:
-                if module not in unique_module_names:
-                    raise ValueError(
-                        f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
-                        f" {self.__class__.__name__}"
-                    )
+            if self._keep_in_fp32_modules is not None:
+                for module in self._keep_in_fp32_modules:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+            if self._keep_in_fp32_modules_strict is not None:
+                for module in self._keep_in_fp32_modules_strict:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
 
         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
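For reference, the check above reduces every parameter name to its non-numeric path components (dropping `weight`/`bias` leaves) and then verifies each listed module against that set. A standalone sketch of the same reduction, with made-up parameter names:

```python
# Hypothetical parameter names, as `named_parameters()` might yield them
all_parameters = {
    "codec_model.encoder.layers.0.weight",
    "model.embed_tokens.weight",
    "lm_head.weight",
}

unique_module_names = set()
for param in all_parameters:
    unique_module_names.update(
        [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
    )

# -> {'codec_model', 'encoder', 'layers', 'model', 'embed_tokens', 'lm_head'}
assert "codec_model" in unique_module_names
# A typo such as "codec" would not be found and would trigger the ValueError above
assert "codec" not in unique_module_names
```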
@@ -4757,20 +4771,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
         config = model.config
 
         # Find fp32 modules if needed
-        keep_in_fp32_regex = None
+        keep_in_fp32_modules = []
         # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
         # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
         # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
         # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
         if model._keep_in_fp32_modules is not None and (
-            torch_dtype == torch.float16
-            or torch_dtype == torch.bfloat16
-            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
         ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
+
+        if model._keep_in_fp32_modules_strict is not None and (
+            torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
+        ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
+
+        keep_in_fp32_regex = None
+        if keep_in_fp32_modules:
             # We need to match exact layers, so we add either `.` on each side, or start/end of string
-            keep_in_fp32_regex = re.compile(
-                "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
-            )
+            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
 
         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
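Taken together, the rewritten block implements a small decision table: the non-strict list applies to float16 loads (or when the quantizer sets `use_keep_in_fp32_modules`), the strict list additionally applies to bfloat16 loads, and the union of both feeds one regex over parameter names. A self-contained sketch of that behavior, with made-up module names (the helper function is illustrative, not part of the commit):

```python
import re

import torch


def select_fp32_modules(torch_dtype, keep_in_fp32, keep_in_fp32_strict, quantizer_wants_fp32=False):
    """Mimics the selection logic above: returns the modules to force to float32."""
    modules = []
    if keep_in_fp32 and (torch_dtype == torch.float16 or quantizer_wants_fp32):
        modules.extend(keep_in_fp32)
    if keep_in_fp32_strict and torch_dtype in (torch.float16, torch.bfloat16):
        modules.extend(keep_in_fp32_strict)
    return modules


# Only the strict entry survives a bfloat16 load
modules = select_fp32_modules(torch.bfloat16, ["proj"], ["codec_model"])
assert modules == ["codec_model"]

# Same regex construction as in the diff: exact module-name matches only
regex = re.compile("|".join([rf"((^|\.){m}($|\.))" for m in modules]))
assert regex.search("codec_model.encoder.layers.0.weight")
assert not regex.search("my_codec_model.weight")
```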

src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py

@@ -1103,7 +1103,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedModel
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]
 
     def __init__(self, config):
         super().__init__(config)

src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py

@@ -252,7 +252,7 @@ class KyutaiSpeechToTextModel(MoshiModel):
 class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]
 
     def __init__(self, config):
         super().__init__(config)
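With the attribute switched to the strict variant in both the modeling and modular files, a bfloat16 load of the released checkpoint should now keep the codec in float32 as well. A short usage sketch mirroring the new test below (it downloads the full ~2.6B checkpoint):

```python
import torch

from transformers import KyutaiSpeechToTextForConditionalGeneration

model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
    "kyutai/stt-2.6b-en-trfs", torch_dtype=torch.bfloat16
)
print(model.codec_model.dtype)  # torch.float32, kept by _keep_in_fp32_modules_strict
print(model.model.dtype)        # torch.bfloat16
```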

tests/models/kyutai_speech_to_text/test_modeling_kyutai_speech_to_text.py

@@ -30,6 +30,7 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
+    require_accelerate,
     require_torch,
     require_torch_accelerator,
     require_torch_sdpa,
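The new test below covers both the accelerate and the non-accelerate code paths by patching `builtins.__import__` so that importing `accelerate` either returns a `Mock` or raises `ImportError`. Here is that pattern in isolation, reduced to its essentials:

```python
import unittest.mock

orig_import = __import__
accelerate_available = False  # flip to simulate accelerate being installed


def import_accelerate_mock(name, *args, **kwargs):
    # Intercept only the module we care about; everything else imports normally
    if name == "accelerate":
        if accelerate_available:
            return unittest.mock.Mock()
        raise ImportError
    return orig_import(name, *args, **kwargs)


with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
    try:
        import accelerate  # noqa: F401
    except ImportError:
        print("accelerate unavailable in this simulated environment")
```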
@@ -615,6 +616,81 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
         self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)
 
 
+@require_torch
+@require_accelerate
+@slow
+class KyutaiSpeechToTextBf16Test(unittest.TestCase):
+    def test_bf16_fp32_conversion(self):
+        r"""
+        A test to check whether `_keep_in_fp32_modules_strict` correctly does its job
+        """
+        model_checkpoint = "kyutai/stt-2.6b-en-trfs"
+
+        orig_import = __import__
+        accelerate_mock = unittest.mock.Mock()
+
+        # mock import of accelerate
+        def import_accelerate_mock(name, *args, **kwargs):
+            if name == "accelerate":
+                if accelerate_available:
+                    return accelerate_mock
+                else:
+                    raise ImportError
+            return orig_import(name, *args, **kwargs)
+
+        # Load without using `accelerate`
+        with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
+            accelerate_available = False
+
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.float16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.float16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+            # Load in bf16 without using `accelerate`
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.bfloat16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.bfloat16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load in bf16 without `device_map`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.bfloat16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load without using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.float16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+        # Load using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.float16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+
 class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
     _dataset = None