Small fix on context manager detection (#37562)

* small fixes

* Update modeling_utils.py

* test

* Update test_modeling_common.py

* Update test_modeling_timm_backbone.py

* more general

* simpler
Author: Cyril Vallez, 2025-04-17 15:39:44 +02:00 (committed by GitHub)
Parent: c7d3cc67a1
Commit: 58e5e976e0
6 changed files with 36 additions and 18 deletions

modeling_utils.py

@@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         _adapter_model_path = None
 
         # Potentially detect context manager or global device, and use it (only if no device_map was provided)
-        if device_map is None:
+        if device_map is None and not is_deepspeed_zero3_enabled():
             device_in_context = get_torch_context_manager_or_global_device()
             if device_in_context == torch.device("meta"):
-                raise ValueError(
-                    (
-                        "`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
-                        "as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
-                        "or global device with `from_config`, or `ModelClass(config)`"
-                    )
+                # TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
+                logger.warning(
+                    "We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
+                    "This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
+                    "the context manager or global device with `from_config`, or `ModelClass(config)`"
                 )
             device_map = device_in_context
 
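The alternatives the new warning points to can be sketched in a few lines; the checkpoint name below is only an illustrative placeholder, and under the new behavior the discouraged pattern merely warns until v4.53:

    import torch
    from transformers import AutoConfig, AutoModel

    checkpoint = "hf-internal-testing/tiny-random-bert"  # placeholder checkpoint for illustration

    # Discouraged: loading weights under a meta context now only warns (it will raise in v4.53),
    # and the returned model is a weight-less meta skeleton.
    with torch.device("meta"):
        model = AutoModel.from_pretrained(checkpoint)

    # Recommended: build the skeleton on meta explicitly, without loading weights.
    config = AutoConfig.from_pretrained(checkpoint)
    with torch.device("meta"):
        model = AutoModel.from_config(config)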
@@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
     return new_device_map
 
 
+def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+    """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
+    a proper `torch.device`.
+    """
+    if device == "disk":
+        return False
+    else:
+        return torch.device(device).type not in ["meta", "cpu"]
+
+
 def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
     """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
@@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
     However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
     """
-    # Remove disk and cpu devices, and cast to proper torch.device
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
     accelerator_device_map = {
-        param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
+        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
     }
     if not len(accelerator_device_map):
         return
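A standalone sketch of what the new filter buys (mirroring, not reusing, the helper above): only placements that resolve to a real accelerator survive, so disk offload, CPU and meta entries are skipped before the warmup allocation.

    import torch

    def is_accelerator_device(device):
        # "disk" is an offload target used in device_map, not a valid torch.device
        if device == "disk":
            return False
        return torch.device(device).type not in ["meta", "cpu"]

    # Hypothetical expanded device_map mixing offload and accelerator placements
    expanded_device_map = {"lm_head.weight": "disk", "embed.weight": "cpu", "layer.0.weight": 0, "mask": "meta"}
    accelerator_device_map = {
        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
    }
    print(accelerator_device_map)  # only the accelerator entry (the one mapped to ordinal 0) remains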

modeling_nat.py

@@ -545,7 +545,7 @@ class NatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
         self.levels = nn.ModuleList(
             [
                 NatStage(

modeling_van.py

@@ -311,7 +311,9 @@ class VanEncoder(nn.Module):
         hidden_sizes = config.hidden_sizes
         depths = config.depths
         mlp_ratios = config.mlp_ratios
-        drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
+        ]
         for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
             zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)

modeling_dinat.py

@@ -553,7 +553,7 @@ class DinatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.levels = nn.ModuleList(
             [
                 DinatStage(
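The `device="cpu"` pin in the Nat, Van and Dinat encoders above is presumably what keeps their `__init__` usable under the meta context the new test exercises: the drop-path rates are read back with `.item()`, which cannot work on a data-less meta tensor. A minimal illustration outside the library:

    import torch

    torch.set_default_device("meta")
    try:
        # Under a meta default device, factory functions return data-less meta tensors,
        # so reading values back with .item() fails.
        [x.item() for x in torch.linspace(0, 0.1, 4)]
    except Exception as exc:
        print(type(exc).__name__, exc)

    # Pinning the factory to CPU keeps the computation valid regardless of the default device.
    rates = [x.item() for x in torch.linspace(0, 0.1, 4, device="cpu")]
    torch.set_default_device("cpu")  # restore a normal default
    print(rates)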

test_modeling_timm_backbone.py

@@ -177,7 +177,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
         pass
 
     @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         pass
 
     @unittest.skip(reason="model weights aren't tied in TimmBackbone.")

test_modeling_common.py

@@ -4590,7 +4590,7 @@ class ModelTesterMixin:
             unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
         )
 
-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4600,10 +4600,17 @@ class ModelTesterMixin:
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
 
-                # This should raise an error with meta device
-                with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
-                    with torch.device("meta"):
-                        _ = model_class.from_pretrained(tmpdirname)
+                with torch.device("meta"):
+                    new_model = model_class.from_pretrained(tmpdirname)
+
+                unique_devices = {param.device for param in new_model.parameters()} | {
+                    buffer.device for buffer in new_model.buffers()
+                }
+                self.assertEqual(
+                    unique_devices,
+                    {torch.device("meta")},
+                    f"All parameters should be on meta device, but found {unique_devices}.",
+                )
 
 global_rng = random.Random()
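For reference, the device sweep the new assertion performs can also be run on any module in isolation; `nn.BatchNorm1d` below is only a stand-in for a model returned by `from_pretrained`:

    import torch
    from torch import nn

    module = nn.BatchNorm1d(8)  # stand-in module with both parameters and buffers
    unique_devices = {p.device for p in module.parameters()} | {b.device for b in module.buffers()}
    assert unique_devices == {torch.device("cpu")}, f"Unexpected devices: {unique_devices}"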