Mirror of https://github.com/huggingface/transformers.git
Small fix on context manager detection (#37562)
* small fixes
* Update modeling_utils.py
* test
* Update test_modeling_common.py
* Update test_modeling_timm_backbone.py
* more general
* simpler
parent c7d3cc67a1
commit 58e5e976e0
@@ -4167,15 +4167,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
         _adapter_model_path = None

         # Potentially detect context manager or global device, and use it (only if no device_map was provided)
-        if device_map is None:
+        if device_map is None and not is_deepspeed_zero3_enabled():
             device_in_context = get_torch_context_manager_or_global_device()
             if device_in_context == torch.device("meta"):
-                raise ValueError(
-                    (
-                        "`from_pretrained` is not compatible with a meta device context manager or `torch.set_default_device('meta')` "
-                        "as its purpose is to load weights. If you want to initialize a model on the meta device, use the context manager "
-                        "or global device with `from_config`, or `ModelClass(config)`"
-                    )
+                # TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
+                logger.warning(
+                    "We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
+                    "This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
+                    "the context manager or global device with `from_config`, or `ModelClass(config)`"
                 )
             device_map = device_in_context

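In effect, this hunk downgrades the hard failure into a warning: loading under a meta device context manager (or `torch.set_default_device('meta')`) now proceeds and reuses the detected device as the `device_map`, so the weights end up on the meta device; the detection is also skipped entirely when DeepSpeed ZeRO-3 is enabled, since DeepSpeed handles parameter placement itself. A minimal sketch of the new behavior (the checkpoint id below is only an example):

import torch
from transformers import AutoModel

# Under a meta device context, from_pretrained now warns instead of raising,
# and the resulting parameters live on the meta device (no real data is loaded).
with torch.device("meta"):
    model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")  # example checkpoint

print({p.device for p in model.parameters()})  # expected: {device(type='meta')}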
@@ -5834,6 +5833,16 @@ def expand_device_map(device_map, param_names):
     return new_device_map


+def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+    """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
+    a proper `torch.device`.
+    """
+    if device == "disk":
+        return False
+    else:
+        return torch.device(device).type not in ["meta", "cpu"]
+
+
 def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, factor=2):
     """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
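The new helper treats anything that is neither "disk", "cpu" nor "meta" as an accelerator. A quick illustration of what it returns (the function below simply mirrors the one added in the hunk):

import torch

def is_accelerator_device(device):
    # mirror of the helper introduced above
    if device == "disk":
        return False
    return torch.device(device).type not in ["meta", "cpu"]

print(is_accelerator_device("cuda:0"))  # True: a real accelerator
print(is_accelerator_device(0))         # True: bare ints resolve to an accelerator index (typically cuda:0)
print(is_accelerator_device("cpu"))     # False
print(is_accelerator_device("meta"))    # False
print(is_accelerator_device("disk"))    # False: "disk" is not a valid torch.device, hence the early return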
@@ -5853,9 +5862,9 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
     - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
     However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
     """
-    # Remove disk and cpu devices, and cast to proper torch.device
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
     accelerator_device_map = {
-        param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
+        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
     }
     if not len(accelerator_device_map):
         return
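Since `from_pretrained` can now continue with a meta `device_map` (warning path above), `caching_allocator_warmup` may see meta entries, and pre-allocating memory for them would be meaningless; the comprehension therefore filters through `is_accelerator_device` instead of the hard-coded `["cpu", "disk"]` check. A small sketch, with made-up parameter names, assuming a transformers version that already contains this commit:

import torch
from transformers.modeling_utils import is_accelerator_device  # added by this commit

expanded_device_map = {
    "encoder.weight": "meta",    # skipped: meta tensors own no memory to warm up
    "decoder.weight": "cuda:0",  # kept: real accelerator memory
    "lm_head.weight": "disk",    # skipped: offloaded to disk, not a torch.device
}
accelerator_device_map = {
    param: torch.device(device)
    for param, device in expanded_device_map.items()
    if is_accelerator_device(device)
}
print(accelerator_device_map)  # {'decoder.weight': device(type='cuda', index=0)}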
@@ -545,7 +545,7 @@ class NatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
         self.levels = nn.ModuleList(
             [
                 NatStage(
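Pinning `device="cpu"` in the drop-path schedule is what keeps these encoders constructible under a meta device context: with `torch.set_default_device("meta")` or a meta context manager active, `torch.linspace` would otherwise return a meta tensor, and the following `.item()` calls cannot read values out of a tensor that has no data. The same fix is applied to `VanEncoder` and `DinatEncoder` in the hunks below. A minimal sketch of the failure mode:

import torch

with torch.device("meta"):
    rates = torch.linspace(0, 0.1, 4)  # created on the meta device by the context manager
    try:
        rates[0].item()                # meta tensors carry no data, so this fails
    except (RuntimeError, NotImplementedError) as e:
        print(f"fails on meta: {e}")

    safe = torch.linspace(0, 0.1, 4, device="cpu")  # explicit device overrides the context
    print([x.item() for x in safe])    # works: ~[0.0, 0.033, 0.067, 0.1]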
@@ -311,7 +311,9 @@ class VanEncoder(nn.Module):
         hidden_sizes = config.hidden_sizes
         depths = config.depths
         mlp_ratios = config.mlp_ratios
-        drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        drop_path_rates = [
+            x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
+        ]

         for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
             zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)
@@ -553,7 +553,7 @@ class DinatEncoder(nn.Module):
         super().__init__()
         self.num_levels = len(config.depths)
         self.config = config
-        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
+        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
         self.levels = nn.ModuleList(
             [
                 DinatStage(
@@ -177,7 +177,7 @@ class TimmBackboneModelTest(ModelTesterMixin, BackboneTesterMixin, PipelineTeste
         pass

     @unittest.skip(reason="TimmBackbone uses its own `from_pretrained` without device_map support")
-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         pass

     @unittest.skip(reason="model weights aren't tied in TimmBackbone.")
@@ -4590,7 +4590,7 @@ class ModelTesterMixin:
                 unique_devices, {device}, f"All parameters should be on {device}, but found {unique_devices}."
             )

-    def test_cannot_load_with_meta_device_context_manager(self):
+    def test_can_load_with_meta_device_context_manager(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
         for model_class in self.all_model_classes:
             # Need to deepcopy here as it is modified in-place in save_pretrained (it sets sdpa for default attn, which
@@ -4600,10 +4600,17 @@ class ModelTesterMixin:
             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)

-                # This should raise an error with meta device
-                with self.assertRaises(ValueError, msg="`from_pretrained` is not compatible with a meta device"):
-                    with torch.device("meta"):
-                        _ = model_class.from_pretrained(tmpdirname)
+                with torch.device("meta"):
+                    new_model = model_class.from_pretrained(tmpdirname)
+                unique_devices = {param.device for param in new_model.parameters()} | {
+                    buffer.device for buffer in new_model.buffers()
+                }
+
+                self.assertEqual(
+                    unique_devices,
+                    {torch.device("meta")},
+                    f"All parameters should be on meta device, but found {unique_devices}.",
+                )


 global_rng = random.Random()