make sure to disable gradients for integer tensor (#32943)

Wing Lian 2024-11-18 10:49:37 -05:00 committed by GitHub
parent 1c471fc307
commit 36759f3312


@@ -927,7 +927,10 @@ def _load_state_dict_into_meta_model(
                 param_to = "cpu"
                 if is_fsdp_enabled() and not is_local_dist_rank_0():
                     param_to = "meta"
-                value = type(value)(value.data.to(param_to), **value.__dict__)
+                val_kwargs = {}
+                if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
+                    val_kwargs["requires_grad"] = False
+                value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
                 setattr(module, tensor_name, value)
 
     # TODO: consider removing used param_parts from state_dict before return
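For context, the extra keyword argument matters because PyTorch refuses to enable autograd on integer tensors, and rebuilding a bitsandbytes Int8Params via type(value)(...) would otherwise fall back to the default requires_grad. Below is a minimal sketch of the failure mode, not part of the commit, using plain torch.nn.Parameter rather than Int8Params (which subclasses it):

# Minimal sketch (assumption: plain torch.nn.Parameter stands in for Int8Params).
# PyTorch cannot track gradients on integer dtypes, so the rebuilt parameter
# must be created with requires_grad=False.
import torch

int8_data = torch.zeros(4, dtype=torch.int8)

try:
    # Default requires_grad=True fails for integer tensors.
    torch.nn.Parameter(int8_data)
except RuntimeError as err:
    print(f"fails as expected: {err}")

# Explicitly disabling gradients, as the patch does for Int8Params, succeeds.
param = torch.nn.Parameter(int8_data, requires_grad=False)
print(param.dtype, param.requires_grad)  # torch.int8 False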