Fix _load_state_dict_into_meta_model with device_map=None (#36488)

* Fix _load_state_dict_into_meta_model with device_map=None

* Update src/transformers/modeling_utils.py
This commit is contained in:
hlky 2025-03-02 07:33:36 +00:00 committed by GitHub
parent a40f1ac602
commit dcbdf7e962
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -785,8 +785,8 @@ def _load_state_dict_into_meta_model(
tensor_device = None
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
if device_map is not None:
device_map_regex = "|".join(sorted(device_map.keys(), reverse=True))
# we need this later to initialize tensor parallelism
if device_mesh is not None: