All models can be initialized on meta device (#37563)

* Update test_modeling_common.py

* fix all

* more fixes
This commit is contained in:
Cyril Vallez 2025-04-16 23:26:44 +02:00 committed by GitHub
parent 0a83588c51
commit 688f4707bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 36 additions and 23 deletions

View File

@ -663,7 +663,7 @@ class BeitEncoder(nn.Module):
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
self.layer = nn.ModuleList(
[
BeitLayer(

View File

@ -829,7 +829,7 @@ class ClapAudioEncoder(nn.Module):
self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
grid_size = self.patch_embed.grid_size
self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]

View File

@ -225,7 +225,8 @@ class ConvNextEncoder(nn.Module):
super().__init__()
self.stages = nn.ModuleList()
drop_path_rates = [
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
x.tolist()
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):

View File

@ -245,7 +245,8 @@ class ConvNextV2Encoder(nn.Module):
super().__init__()
self.stages = nn.ModuleList()
drop_path_rates = [
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
x.tolist()
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
]
prev_chs = config.hidden_sizes[0]
for i in range(config.num_stages):

View File

@ -449,7 +449,9 @@ class CvtStage(nn.Module):
dropout_rate=config.drop_rate[self.stage],
)
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]
drop_path_rates = [
x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu")
]
self.layers = nn.Sequential(
*[

View File

@ -676,7 +676,7 @@ class Data2VecVisionEncoder(nn.Module):
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
self.layer = nn.ModuleList(
[
Data2VecVisionLayer(

View File

@ -790,7 +790,7 @@ class DonutSwinEncoder(nn.Module):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.layers = nn.ModuleList(
[
DonutSwinStage(

View File

@ -486,7 +486,7 @@ class FocalNetStage(nn.Module):
downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]
self.layers = nn.ModuleList(

View File

@ -331,7 +331,7 @@ class GLPNEncoder(nn.Module):
self.config = config
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
# patch embeddings
embeddings = []

View File

@ -639,9 +639,9 @@ class HieraEncoder(nn.Module):
super().__init__()
total_depth = sum(config.depths)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth)]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth, device="cpu")]
# query strides rule
cumulative_depths = torch.tensor(config.depths).cumsum(0).tolist()
cumulative_depths = torch.tensor(config.depths, device="cpu").cumsum(0).tolist()
query_pool_layer = cumulative_depths[: config.num_query_pool]
query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]

View File

@ -692,7 +692,7 @@ class MaskFormerSwinEncoder(nn.Module):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.layers = nn.ModuleList(
[
MaskFormerSwinStage(

View File

@ -246,7 +246,7 @@ class MgpstrEncoder(nn.Module):
def __init__(self, config: MgpstrConfig):
super().__init__()
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
self.blocks = nn.Sequential(
*[MgpstrLayer(config=config, drop_path=dpr[i]) for i in range(config.num_hidden_layers)]

View File

@ -194,7 +194,7 @@ class PoolFormerEncoder(nn.Module):
super().__init__()
self.config = config
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
# patch embeddings
embeddings = []

View File

@ -369,7 +369,7 @@ class PvtEncoder(nn.Module):
self.config = config
# stochastic depth decay rule
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths)).tolist()
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
# patch embeddings
embeddings = []

View File

@ -323,7 +323,7 @@ class PvtV2EncoderLayer(nn.Module):
)
# Transformer block
# stochastic depth decay rule
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths)).tolist()
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
block_layers = []
for block_idx in range(config.depths[layer_idx]):
block_layers.append(

View File

@ -356,7 +356,9 @@ class SegformerEncoder(nn.Module):
self.config = config
# stochastic depth decay rule
drop_path_decays = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
drop_path_decays = [
x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
]
# patch embeddings
embeddings = []

View File

@ -460,7 +460,7 @@ class SegGptEncoder(nn.Module):
def __init__(self, config: SegGptConfig) -> None:
super().__init__()
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False

View File

@ -823,7 +823,7 @@ class SwinEncoder(nn.Module):
super().__init__()
self.num_layers = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.layers = nn.ModuleList(
[
SwinStage(

View File

@ -682,7 +682,7 @@ class Swin2SREncoder(nn.Module):
super().__init__()
self.num_stages = len(config.depths)
self.config = config
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
self.stages = nn.ModuleList(
[
Swin2SRStage(

View File

@ -877,7 +877,7 @@ class Swinv2Encoder(nn.Module):
self.config = config
if self.config.pretrained_window_sizes is not None:
pretrained_window_sizes = config.pretrained_window_sizes
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
layers = []
for i_layer in range(self.num_layers):

View File

@ -295,7 +295,7 @@ class TimesformerLayer(nn.Module):
attention_type = config.attention_type
drop_path_rates = [
x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")
] # stochastic depth decay rule
drop_path_rate = drop_path_rates[layer_index]

View File

@ -535,7 +535,7 @@ class VitDetEncoder(nn.Module):
depth = config.num_hidden_layers
# stochastic depth decay rule
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth)]
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")]
layers = []
for i in range(depth):

View File

@ -4528,6 +4528,13 @@ class ModelTesterMixin:
),
)
def test_can_be_initialized_on_meta(self):
# Smoke test: every class listed in `self.all_model_classes` must be
# constructible from the common test config while `torch.device("meta")`
# is active as a context manager. There are no assertions; the test only
# fails if a constructor raises.
# Only the config is used here; the prepared inputs are discarded.
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# If it does not raise here, the test passes
# NOTE(review): presumably this guards against `__init__` code that reads
# concrete tensor values (e.g. `.item()` / `.tolist()`), which needs real
# data rather than meta tensors -- confirm against the model fixes in the
# same commit.
with torch.device("meta"):
# The instance is intentionally thrown away; only construction matters.
_ = model_class(config)
@require_torch_accelerator
def test_can_load_with_device_context_manager(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()