Skip DeepSpeed ZeRO Stage 3 model initialization when using bnb quantization (#34395)

* Skip DeepSpeed ZeRO Stage 3 model initialization when the model is intended to be quantized.

* Propagate the quantization state using a context manager

* make fixup
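
For context, a hedged usage sketch (not part of this commit) of the path the change targets: loading a bitsandbytes-quantized checkpoint while a DeepSpeed ZeRO-3 config is active. The checkpoint name below is a placeholder, and the ZeRO-3 config is assumed to be supplied through the usual Trainer/accelerate DeepSpeed integration.

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantize on load with bitsandbytes (4-bit here as an example).
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

# With a ZeRO-3 DeepSpeed config active, from_pretrained now enters
# set_quantized_state(), so the nested _from_config() call sees
# _is_quantized=True and skips wrapping construction in deepspeed.zero.Init().
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",  # placeholder checkpoint
    quantization_config=bnb_config,
)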
AbdelKarim ELJANDOUBI 2024-11-05 10:06:07 +01:00 committed by GitHub
parent eb811449a2
commit d0b1d8d888

@@ -136,6 +136,7 @@ logger = logging.get_logger(__name__)
 _init_weights = True
+_is_quantized = False
 def is_fsdp_enabled():
@@ -213,6 +214,16 @@ def no_init_weights(_enable=True):
                 setattr(torch.nn.init, name, init_func)
+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
 def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
@@ -1531,7 +1542,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 torch_dtype=torch_dtype,
             )
-        if is_deepspeed_zero3_enabled():
+        if is_deepspeed_zero3_enabled() and not _is_quantized:
             import deepspeed
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
@@ -4086,6 +4097,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 )
             init_contexts.append(init_empty_weights())
+        if is_deepspeed_zero3_enabled() and is_quantized:
+            init_contexts.append(set_quantized_state())
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         if not getattr(config, "_attn_implementation_autoset", False):
             config = cls._autoset_attn_implementation(
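
To make the control flow above easier to follow, here is a minimal, self-contained sketch of the state-propagation pattern (simplified names; not the actual transformers module): a module-level flag, a context manager that flips and restores it, and a nested initializer that reads the flag instead of taking an extra argument.

from contextlib import ExitStack, contextmanager

_is_quantized = False  # module-level flag read by nested initialization code


@contextmanager
def set_quantized_state():
    # Flip the flag for the duration of the block, then always restore it.
    global _is_quantized
    _is_quantized = True
    try:
        yield
    finally:
        _is_quantized = False


def init_model(zero3_enabled: bool) -> str:
    # Stand-in for _from_config(): skip ZeRO-3's zero.Init() when quantizing.
    if zero3_enabled and not _is_quantized:
        return "wrapped in deepspeed.zero.Init()"
    return "plain initialization"


# from_pretrained-style caller: enter the context before constructing the model.
with ExitStack() as init_contexts:
    init_contexts.enter_context(set_quantized_state())
    print(init_model(zero3_enabled=True))  # -> plain initialization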