Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Skip DeepSpeed ZeRO Stage 3 model initialization when bnb (#34395)
* Skip DeepSpeed ZeRO Stage 3 model initialization when it is intended to be quantized.
* Propagate the quantization state using a context manager
* make fixup
parent: eb811449a2
commit: d0b1d8d888
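In short, the commit threads a single piece of state ("this model is about to be quantized") from from_pretrained down to _from_config without adding a parameter to every call in between: a module-level flag that a context manager flips for the duration of model construction. A minimal, self-contained sketch of that pattern (build_submodel and the print calls are illustrative only, not transformers code):

from contextlib import contextmanager

_is_quantized = False  # module-level flag, as added by this commit


@contextmanager
def set_quantized_state():
    # Flip the flag for the duration of the `with` block, then always restore it.
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        _is_quantized = False


def build_submodel():
    # Illustrative stand-in for `_from_config`: deep in the construction path,
    # the flag decides whether the DeepSpeed ZeRO-3 init wrapper is skipped.
    if _is_quantized:
        print("quantization pending -> skip zero.Init(), build normally")
    else:
        print("no quantizer -> wrap construction in deepspeed.zero.Init()")


# Illustrative stand-in for `from_pretrained`: no new argument is threaded
# through the call chain; entering the context propagates the state instead.
with set_quantized_state():
    build_submodel()  # quantization pending -> skip zero.Init(), build normally
build_submodel()      # no quantizer -> wrap construction in deepspeed.zero.Init()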
@@ -136,6 +136,7 @@ logger = logging.get_logger(__name__)


 _init_weights = True
+_is_quantized = False


 def is_fsdp_enabled():
@@ -213,6 +214,16 @@ def no_init_weights(_enable=True):
                 setattr(torch.nn.init, name, init_func)


+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
+
+
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
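The try/finally in set_quantized_state matters because model construction can raise; the flag must not stay stuck at True for the next from_pretrained call. A quick illustration (failing_build is made up for the example):

from contextlib import contextmanager

_is_quantized = False


@contextmanager
def set_quantized_state():
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        _is_quantized = False  # runs even if the body raised


def failing_build():
    raise RuntimeError("checkpoint could not be loaded")


try:
    with set_quantized_state():
        failing_build()
except RuntimeError:
    pass

assert _is_quantized is False  # flag restored despite the exception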
@@ -1531,7 +1542,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 torch_dtype=torch_dtype,
             )

-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not _is_quantized:
             import deepspeed

             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
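In effect, _from_config now only wraps construction in deepspeed.zero.Init when no quantizer is pending, because quantizers such as bitsandbytes need to see the full, un-partitioned weights. A simplified, hedged sketch of that control flow (not the verbatim transformers code; requires transformers, plus deepspeed on the ZeRO-3 path):

from transformers.integrations.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled

_is_quantized = False  # stands in for the module-level flag added by this commit


def build_from_config(cls, config, **kwargs):
    if is_deepspeed_zero3_enabled() and not _is_quantized:
        import deepspeed

        # ZeRO-3 path: zero.Init() partitions parameters across ranks as the
        # modules are created, so the full model never materializes on one device.
        with deepspeed.zero.Init(config_dict_or_params=deepspeed_config()):
            return cls(config, **kwargs)

    # Quantized (or non-ZeRO-3) path: plain construction, so the quantizer can
    # later inspect and replace the un-partitioned weights.
    return cls(config, **kwargs)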
@@ -4086,6 +4097,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             init_contexts.append(init_empty_weights())

+        if is_deepspeed_zero3_enabled() and is_quantized:
+            init_contexts.append(set_quantized_state())
+
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
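The contexts collected in init_contexts are entered together around the actual model instantiation, which is how set_quantized_state() ends up active when the guard in the previous hunk runs. transformers does this with its own ContextManagers helper; the sketch below substitutes the standard-library contextlib.ExitStack and a hypothetical placeholder context to show the mechanism:

from contextlib import ExitStack, contextmanager

_is_quantized = False


@contextmanager
def set_quantized_state():
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        _is_quantized = False


@contextmanager
def no_init_weights_stub():
    # Hypothetical placeholder for the other init contexts (e.g. no_init_weights()).
    yield


init_contexts = [no_init_weights_stub()]
quantized_under_zero3 = True  # stands in for `is_deepspeed_zero3_enabled() and is_quantized`
if quantized_under_zero3:
    init_contexts.append(set_quantized_state())

with ExitStack() as stack:
    for ctx in init_contexts:
        stack.enter_context(ctx)
    # The model class would be instantiated here; any code checking the
    # module-level flag (like the guard in the previous hunk) now sees True.
    print(_is_quantized)  # True

print(_is_quantized)  # False again once every context has exited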