diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b1ec6896be8..da2f07c0a66 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
             _adapter_model_path = None
 
         # Potentially detect context manager or global device, and use it (only if no device_map was provided)
-        if device_map is None:
+        if device_map is None and not is_deepspeed_zero3_enabled():
             device_in_context = get_torch_context_manager_or_global_device()
             if device_in_context == torch.device("meta"):
-                raise ValueError(
-                    (
-                        "`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
-                        "as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
-                        "or global device with `from_config`, or `ModelClass(config)`"
-                    )
+                # TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
+                logger.warning(
+                    "We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
+                    "This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
+                    "the context manager or global device with `from_config`, or `ModelClass(config)`"
                 )
             device_map = device_in_context
 
@@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
     return new_device_map
 
 
+def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+    """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
+    a proper `torch.device`.
+    """
+    if device == "disk":
+        return False
+    else:
+        return torch.device(device).type not in ["meta", "cpu"]
+
+
 def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
     """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
@@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the
     devices. However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
     """
-    # Remove disk and cpu devices, and cast to proper torch.device
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
     accelerator_device_map = {
-        param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
+        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
     }
     if not len(accelerator_device_map):
         return
diff --git a/src/transformers/models/deprecated/nat/modeling_nat.py b/src/transformers/models/deprecated/nat/modeling_nat.py
index 5871b03299a..70ecffcf51e 100644
--- a/src/transformers/models/deprecated/nat/modeling_nat.py
+++ b/src/transformers/models/deprecated/nat/modeling_nat.py
@@ -545,7 +545,7 @@ class NatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
         self.levels = nn.ModuleList(
             [
                 NatStage(
diff --git a/src/transformers/models/deprecated/van/modeling_van.py b/src/transformers/models/deprecated/van/modeling_van.py
index fd11a04ec21..1da03cb544d 100644
--- a/src/transformers/models/deprecated/van/modeling_van.py
+++ b/src/transformers/models/deprecated/van/modeling_van.py
@@ -311,7 +311,9 @@ class VanEncoder(nn.Module):
         hidden_sizes = config.hidden_sizes
         depths = config.depths
         mlp_ratios = config.mlp_ratios
-        drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
+        ]
 
         for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
             zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)
diff --git a/src/transformers/models/dinat/modeling_dinat.py b/src/transformers/models/dinat/modeling_dinat.py
index 0e0121b78da..8837372a84c 100644
--- a/src/transformers/models/dinat/modeling_dinat.py
+++ b/src/transformers/models/dinat/modeling_dinat.py
@@ -553,7 +553,7 @@ class DinatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
         self.levels = nn.ModuleList(
             [
                 DinatStage(
diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py
index 582bdab0b58..d060ab38886 100644
--- a/tests/models/timm_backbone/test_modeling_timm_backbone.py
+++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py
@@ -177,7 +177,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
         pass
 
     @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         pass
 
     @unittest.skip(reason="model weights aren't tied in TimmBackbone.")
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index e546d023f5e..fca89147f42 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -4590,7 +4590,7 @@ class ModelTesterMixin:
                 unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
             )
 
-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
             # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4600,10 +4600,17 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
 
-                # This should raise an error with meta device
-                with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
-                    with torch.device("meta"):
-                        _ = model_class.from_pretrained(tmpdirname)
+                with torch.device("meta"):
+                    new_model = model_class.from_pretrained(tmpdirname)
+                unique_devices = {param.device for param in new_model.parameters()} | {
+                    buffer.device for buffer in new_model.buffers()
+                }
+
+                self.assertEqual(
+                    unique_devices,
+                    {torch.device("meta")},
+                    f"All parameters should be on meta device, but found {unique_devices}.",
+                )
 
 
 global_rng = random.Random()
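
A minimal sketch of the behavior the updated `test_can_load_with_meta_device_context_manager` exercises, assuming this patch is applied; the checkpoint name is only an illustrative example:

```python
# Under a meta device context, `from_pretrained` now warns (per the TODO it will raise
# again in v4.53) and returns a model whose parameters and buffers all live on meta.
import torch
from transformers import AutoModel

with torch.device("meta"):
    model = AutoModel.from_pretrained("bert-base-uncased")  # example checkpoint

devices = {p.device for p in model.parameters()} | {b.device for b in model.buffers()}
assert devices == {torch.device("meta")}
```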
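A small sketch of how the new `is_accelerator_device` helper filters the warmup map, assuming this patch is applied (the device strings are only examples):

```python
# "disk" is special-cased before torch.device() is called, since it is not a valid
# torch.device; only non-cpu, non-meta placements trigger allocator warmup.
from transformers.modeling_utils import is_accelerator_device

print(is_accelerator_device("disk"))    # False, handled before torch.device() is built
print(is_accelerator_device("cpu"))     # False
print(is_accelerator_device("meta"))    # False, meta weights need no warmup
print(is_accelerator_device("cuda:0"))  # True
```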
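A short illustration of why the drop-path schedules in Nat, Dinat and Van now pin `device="cpu"`, assuming torch >= 2.0 where `torch.device(...)` acts as a default-device context manager:

```python
# With a meta default device active, a bare torch.linspace() is allocated on meta and
# carries no data, so the .item() calls in these encoders would fail; pinning the
# schedule to CPU keeps model construction working.
import torch

with torch.device("meta"):
    on_meta = torch.linspace(0, 0.1, 4)               # allocated on the meta device
    on_cpu = torch.linspace(0, 0.1, 4, device="cpu")  # explicitly kept on CPU

print(on_meta.device)              # meta
print([x.item() for x in on_cpu])  # works: concrete CPU values
# [x.item() for x in on_meta] would raise, since meta tensors have no storage
```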