Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 18:51:14 +06:00
add _keep_in_fp32_modules_strict (#39058)

* add _keep_in_fp32_modules_strict
* complete test

Parent: d973e62fdd
Commit: 02ecdcfc0f
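In short: `_keep_in_fp32_modules` keeps the listed submodules in float32 when a model is loaded in float16, but still allows them to be cast to bfloat16, while the new `_keep_in_fp32_modules_strict` keeps them in float32 under both float16 and bfloat16. A minimal sketch of how a model class would declare the two flags; the class and module names below are placeholders for illustration, not part of this commit:

from transformers import PreTrainedModel

class MyToyModel(PreTrainedModel):  # hypothetical model, for illustration only
    # kept in fp32 when loading with torch_dtype=torch.float16; may still be cast to bfloat16
    _keep_in_fp32_modules = ["feature_projector"]
    # kept in fp32 when loading with torch_dtype=torch.float16 or torch.bfloat16
    _keep_in_fp32_modules_strict = ["codec_model"]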
@@ -1937,7 +1937,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     _auto_class = None
     _no_split_modules = None
     _skip_keys_device_placement = None
+
     _keep_in_fp32_modules = None
+    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
+    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
+    _keep_in_fp32_modules_strict = None

     # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
     # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2049,6 +2053,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
         # when a different component (e.g. language_model) is used.
         self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)

         self._no_split_modules = self._no_split_modules or []

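A side note on the `copy.copy` above: copying the class-level list onto the instance lets a model instance edit its own list later (as `InstructBlipForConditionalGeneration` does when a different language model is used) without mutating the attribute shared by every other instance. An illustrative sketch of the pitfall the copy avoids; class `A` is a made-up example, not transformers code:

import copy

class A:
    shared = ["x"]  # class-level list, like _keep_in_fp32_modules

a1, a2 = A(), A()
a1.shared = copy.copy(A.shared)   # per-instance copy, as in __init__ above
a1.shared.append("y")             # only a1 sees the change
assert A.shared == ["x"] and a2.shared == ["x"]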
@@ -2061,7 +2066,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         self._backward_compatibility_gradient_checkpointing()

         # Make sure the modules correctly exist if the flag is active
-        if self._keep_in_fp32_modules is not None:
+        if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
             all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
             unique_module_names = set()
             # Get all unique module names in the module graph, without the prefixes
@@ -2070,12 +2075,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
                     [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
                 )
             # Check that every module in the keep_in_fp32 list is part of the module graph
-            for module in self._keep_in_fp32_modules:
-                if module not in unique_module_names:
-                    raise ValueError(
-                        f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
-                        f" {self.__class__.__name__}"
-                    )
+            if self._keep_in_fp32_modules is not None:
+                for module in self._keep_in_fp32_modules:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+            if self._keep_in_fp32_modules_strict is not None:
+                for module in self._keep_in_fp32_modules_strict:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )

         # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
         self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
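To see what the existence check above validates: `unique_module_names` is built from parameter names by dropping numeric components (e.g. layer indices) and the trailing `weight`/`bias`, so every flag entry has to name an actual submodule. A standalone sketch of the same derivation on a toy module (the toy module is illustrative, not from the commit):

import torch.nn as nn

toy = nn.ModuleDict({"codec_model": nn.Linear(2, 2), "lm_head": nn.Sequential(nn.Linear(2, 2))})

unique_module_names = set()
for name, _ in toy.named_parameters():
    unique_module_names.update(
        part for part in name.split(".") if not part.isnumeric() and part not in ["weight", "bias"]
    )

print(unique_module_names)  # {'codec_model', 'lm_head'}; a typo like "codec_modle" would trigger the ValueError above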
@@ -4757,20 +4771,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         config = model.config

         # Find fp32 modules if needed
-        keep_in_fp32_regex = None
+        keep_in_fp32_modules = []
         # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
         # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
         # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
-        # Update: to extend _keep_in_fp32_modules flag feature, it can also be used to force modules that should stay in fp32
         if model._keep_in_fp32_modules is not None and (
-            torch_dtype == torch.float16
-            or torch_dtype == torch.bfloat16
-            or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+            torch_dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
         ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
+
+        if model._keep_in_fp32_modules_strict is not None and (
+            torch_dtype == torch.float16 or torch_dtype == torch.bfloat16
+        ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
+
+        keep_in_fp32_regex = None
+        if keep_in_fp32_modules:
             # We need to match exact layers, so we add either `.` on each side, or start/end of string
-            keep_in_fp32_regex = re.compile(
-                "|".join([rf"((^|\.){module}($|\.))" for module in model._keep_in_fp32_modules])
-            )
+            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

         if hf_quantizer is not None:
             hf_quantizer.preprocess_model(
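For reference, the regex built above matches a flag entry only as a complete dot-separated component of a parameter name, so unrelated modules whose names merely contain the string are left alone. A quick check with illustrative parameter names (the names are not taken from the commit):

import re

keep_in_fp32_modules = ["codec_model"]
keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))

assert keep_in_fp32_regex.search("codec_model.encoder.weight")          # kept in fp32
assert keep_in_fp32_regex.search("model.codec_model.proj.bias")         # also matched mid-path
assert not keep_in_fp32_regex.search("my_codec_model.encoder.weight")   # substring only, not matched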
@@ -1103,7 +1103,7 @@ class KyutaiSpeechToTextForConditionalGeneration(KyutaiSpeechToTextPreTrainedMod
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]

     def __init__(self, config):
         super().__init__(config)
@@ -252,7 +252,7 @@ class KyutaiSpeechToTextModel(MoshiModel):


 class KyutaiSpeechToTextForConditionalGeneration(LlamaForCausalLM, GenerationMixin, PreTrainedModel):
-    _keep_in_fp32_modules = ["codec_model"]
+    _keep_in_fp32_modules_strict = ["codec_model"]

     def __init__(self, config):
         super().__init__(config)
@@ -30,6 +30,7 @@ from transformers import (
 )
 from transformers.testing_utils import (
     cleanup,
+    require_accelerate,
     require_torch,
     require_torch_accelerator,
     require_torch_sdpa,
@@ -615,6 +616,81 @@ class KyutaiSpeechToTextModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
         self._check_similar_generate_outputs(res_eager, res_attn, atol=1e-3, rtol=1e-3)


+@require_torch
+@require_accelerate
+@slow
+class KyutaiSpeechToTextBf16Test(unittest.TestCase):
+    def test_bf16_fp32_conversion(self):
+        r"""
+        A test to check whether the argument `keep_in_fp32_modules` correctly does its job
+        """
+        model_checkpoint = "kyutai/stt-2.6b-en-trfs"
+
+        orig_import = __import__
+        accelerate_mock = unittest.mock.Mock()
+
+        # mock import of accelerate
+        def import_accelerate_mock(name, *args, **kwargs):
+            if name == "accelerate":
+                if accelerate_available:
+                    return accelerate_mock
+                else:
+                    raise ImportError
+            return orig_import(name, *args, **kwargs)
+
+        # Load without using `accelerate`
+        with unittest.mock.patch("builtins.__import__", side_effect=import_accelerate_mock):
+            accelerate_available = False
+
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.float16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.float16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+            # Load without in bf16
+            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+                model_checkpoint, torch_dtype=torch.bfloat16
+            )
+            self.assertTrue(model.codec_model.dtype == torch.float32)
+            self.assertTrue(model.model.dtype == torch.bfloat16)
+            self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load using `accelerate` in bf16
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.bfloat16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.bfloat16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.bfloat16)
+
+        # Load without using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint,
+            torch_dtype=torch.float16,
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+        # Load using `accelerate`
+        model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(
+            model_checkpoint, torch_dtype=torch.float16, device_map="auto"
+        )
+        self.assertTrue(model.codec_model.dtype == torch.float32)
+        self.assertTrue(model.model.dtype == torch.float16)
+        self.assertTrue(model.lm_head.weight.data.dtype == torch.float16)
+
+
 class KyutaiSpeechToTextForConditionalGenerationIntegrationTests(unittest.TestCase):
     _dataset = None
