mirror of https://github.com/huggingface/transformers.git
Revert "fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)" (#32477)
* Revert "fixes to properly shard FSDP across cpu and meta for cpu_efficient_loading for prequantized 4bit (#32276)"
This reverts commit 62c60a3018.
We uncovered an issue with this change that caused our training runs to hang.
* `is_torchdynamo_compiling` -- cast a wide exception net (#32476)
* cast a wide net
* make fix-copies with a few manual changes
* add copied from
---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
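
The `is_torchdynamo_compiling` change folded into this commit widens the exception handling around the compile-detection check so that any failure is treated as "not compiling" rather than raised. A minimal sketch of that pattern, assuming the usual `torch.compiler` / `torch._dynamo` entry points (illustrative, not the exact library code):

def is_torchdynamo_compiling() -> bool:
    # "Cast a wide net": any failure -- an old torch without torch.compiler,
    # no torch installed at all, or a dynamo import error -- means False.
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        try:
            import torch._dynamo as dynamo

            return dynamo.is_compiling()
        except Exception:
            return False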
parent 4fdc7020b2
commit ac2707e8ee
src/transformers/modeling_utils.py
@@ -933,8 +933,6 @@ def _load_state_dict_into_meta_model(
                 )
             )
         ):
-            if is_fsdp_enabled():
-                param_device = "cpu" if is_local_dist_rank_0() else "meta"
             # For backward compatibility with older versions of `accelerate` and for non-quantized params
             set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
         else:
@@ -945,10 +943,7 @@ def _load_state_dict_into_meta_model(
             if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
                 module, tensor_name = get_module_from_name(model, param_name)
                 value = getattr(module, tensor_name)
-                param_to = "cpu"
-                if is_fsdp_enabled() and not is_local_dist_rank_0():
-                    param_to = "meta"
-                value = type(value)(value.data.to(param_to), **value.__dict__)
+                value = type(value)(value.data.to("cpu"), **value.__dict__)
                 setattr(module, tensor_name, value)
     # TODO: consider removing used param_parts from state_dict before return
 
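The lines removed in the hunks above are the cpu/meta placement that #32276 introduced: under FSDP, only the first process on each node kept real tensors on CPU, while every other rank held meta tensors until FSDP distributed the weights. A minimal standalone sketch of that placement rule, using a hypothetical helper and the LOCAL_RANK convention in place of the library's `is_fsdp_enabled()` / `is_local_dist_rank_0()`:

import os

import torch


def loading_device(fsdp_enabled: bool) -> str:
    # Hypothetical stand-in for the reverted rule: with FSDP enabled, only
    # local rank 0 materializes weights on CPU; other ranks use "meta".
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    return "meta" if fsdp_enabled and local_rank != 0 else "cpu"


# Example: a non-zero local rank would allocate only tensor metadata here.
layer = torch.nn.Linear(8, 8, device=loading_device(fsdp_enabled=True))

After the revert, every rank again materializes the tensors on CPU.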
src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
-import inspect
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from packaging import version
@@ -208,16 +207,11 @@ class Bnb4BitHfQuantizer(HfQuantizer):
                     if unexpected_keys is not None and k in unexpected_keys:
                         unexpected_keys.remove(k)
 
-            param_kwargs = {}
-            sig = inspect.signature(bnb.nn.Params4bit.from_prequantized)
-            if "module" in sig.parameters:
-                param_kwargs["module"] = module
             new_value = bnb.nn.Params4bit.from_prequantized(
                 data=param_value,
                 quantized_stats=quantized_stats,
                 requires_grad=False,
                 device=target_device,
-                **param_kwargs,
             )
         else:
             new_value = param_value.to("cpu")
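The second hunk above removes a signature probe that #32276 added so that `module=` was only forwarded to `bnb.nn.Params4bit.from_prequantized` when the installed bitsandbytes release accepts that keyword. A sketch of that guard as a standalone helper (hypothetical name, assuming bitsandbytes is importable):

import inspect

import bitsandbytes as bnb


def prequantized_extra_kwargs(module) -> dict:
    # Mirror of the removed guard: pass `module=` only if the installed
    # bitsandbytes version's from_prequantized signature supports it.
    sig = inspect.signature(bnb.nn.Params4bit.from_prequantized)
    return {"module": module} if "module" in sig.parameters else {}

With the revert, `from_prequantized` is called with its original keyword set regardless of the installed bitsandbytes version.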