Small fix on context manager detection (#37562)

* small fixes

* Update modeling_utils.py

* test

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* more general

* simpler
This commit is contained in:
Cyril Vallez 2025-04-17 15:39:44 +02:00 committed by GitHub
parent c7d3cc67a1
commit 58e5e976e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 36 additions and 18 deletions

View File

@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
_adapter_model_path = None
# Potentially detect context manager or global device, and use it (only if no device_map was provided)
if device_map is None:
if device_map is None and not is_deepspeed_zero3_enabled():
device_in_context = get_torch_context_manager_or_global_device()
if device_in_context == torch.device("meta"):
raise ValueError(
(
"`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
"as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
"or global device with `from_config`, or `ModelClass(config)`"
)
# TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
logger.warning(
"We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
"This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
"the context manager or global device with `from_config`, or `ModelClass(config)`"
)
device_map = device_in_context
@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
return new_device_map
def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
    """Tell whether `device` designates an accelerator.

    A dedicated helper is needed because a `device_map` entry can also be the string "disk", which is not a
    proper `torch.device`. "disk", as well as any device whose type is "meta" or "cpu", is not an accelerator.
    """
    # Rule out "disk" before the `torch.device` conversion (short-circuit keeps it from being converted).
    return device != "disk" and torch.device(device).type not in ("meta", "cpu")
def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
- Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
"""
# Remove disk and cpu devices, and cast to proper torch.device
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
}
if not len(accelerator_device_map):
return

View File

@ -545,7 +545,7 @@ class NatEncoder(nn.Module):
super().__init__()
self.num_levels = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.levels = nn.ModuleList(
[
NatStage(

View File

@ -311,7 +311,9 @@ class VanEncoder(nn.Module):
hidden_sizes = config.hidden_sizes
depths = config.depths
mlp_ratios = config.mlp_ratios
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
drop_path_rates = [
x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
]
for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)

View File

@ -553,7 +553,7 @@ class DinatEncoder(nn.Module):
super().__init__()
self.num_levels = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.levels = nn.ModuleList(
[
DinatStage(

View File

@ -177,7 +177,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
pass
@unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
def test_cannot_load_with_meta_device_context_manager(self):
def test_can_load_with_meta_device_context_manager(self):
pass
@unittest.skip(reason="model weights aren't tied in TimmBackbone.")

View File

@ -4590,7 +4590,7 @@ class ModelTesterMixin:
unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
)
def test_cannot_load_with_meta_device_context_manager(self):
def test_can_load_with_meta_device_context_manager(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@ -4600,10 +4600,17 @@ class ModelTesterMixin:
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
# This should raise an error with meta device
with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
with torch.device("meta"):
_ = model_class.from_pretrained(tmpdirname)
with torch.device("meta"):
new_model = model_class.from_pretrained(tmpdirname)
unique_devices = {param.device for param in new_model.parameters()} | {
buffer.device for buffer in new_model.buffers()
}
self.assertEqual(
unique_devices,
{torch.device("meta")},
f"All parameters should be on meta device, but found {unique_devices}.",
)
global_rng = random.Random()