Mirror of https://github.com/huggingface/transformers.git
(synced 2025-08-03 03:31:05 +06:00)
make sure to disable gradients for integer tensor (#32943)
This commit is contained in:
parent
1c471fc307
commit
36759f3312
@@ -927,7 +927,10 @@ def _load_state_dict_into_meta_model(
                 param_to = "cpu"
                 if is_fsdp_enabled() and not is_local_dist_rank_0():
                     param_to = "meta"
-                value = type(value)(value.data.to(param_to), **value.__dict__)
+                val_kwargs = {}
+                if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
+                    val_kwargs["requires_grad"] = False
+                value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
                 setattr(module, tensor_name, value)

         # TODO: consider removing used param_parts from state_dict before return
Loading…
Reference in New Issue
Block a user