Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-04 13:20:12 +06:00)
All models can be initialized on meta device (#37563)
* Update test_modeling_common.py
* fix all
* more fixes
This commit is contained in:
parent 0a83588c51
commit 688f4707bf
@@ -663,7 +663,7 @@ class BeitEncoder(nn.Module):
|
|||||||
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
|
self.relative_position_bias = BeitRelativePositionBias(config, window_size=window_size)
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
||||||
self.layer = nn.ModuleList(
|
self.layer = nn.ModuleList(
|
||||||
[
|
[
|
||||||
BeitLayer(
|
BeitLayer(
|
||||||
|
@@ -829,7 +829,7 @@ class ClapAudioEncoder(nn.Module):
|
|||||||
|
|
||||||
self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
|
self.num_features = int(config.patch_embeds_hidden_size * 2 ** (self.num_layers - 1))
|
||||||
|
|
||||||
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
|
|
||||||
grid_size = self.patch_embed.grid_size
|
grid_size = self.patch_embed.grid_size
|
||||||
self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
|
self.input_resolutions = [(grid_size[0] // (2**i), grid_size[1] // (2**i)) for i in range(self.num_layers)]
|
||||||
|
@@ -225,7 +225,8 @@ class ConvNextEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.stages = nn.ModuleList()
|
self.stages = nn.ModuleList()
|
||||||
drop_path_rates = [
|
drop_path_rates = [
|
||||||
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
|
x.tolist()
|
||||||
|
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
|
||||||
]
|
]
|
||||||
prev_chs = config.hidden_sizes[0]
|
prev_chs = config.hidden_sizes[0]
|
||||||
for i in range(config.num_stages):
|
for i in range(config.num_stages):
|
||||||
|
@@ -245,7 +245,8 @@ class ConvNextV2Encoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.stages = nn.ModuleList()
|
self.stages = nn.ModuleList()
|
||||||
drop_path_rates = [
|
drop_path_rates = [
|
||||||
x.tolist() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths)).split(config.depths)
|
x.tolist()
|
||||||
|
for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").split(config.depths)
|
||||||
]
|
]
|
||||||
prev_chs = config.hidden_sizes[0]
|
prev_chs = config.hidden_sizes[0]
|
||||||
for i in range(config.num_stages):
|
for i in range(config.num_stages):
|
||||||
|
@@ -449,7 +449,9 @@ class CvtStage(nn.Module):
|
|||||||
dropout_rate=config.drop_rate[self.stage],
|
dropout_rate=config.drop_rate[self.stage],
|
||||||
)
|
)
|
||||||
|
|
||||||
drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage])]
|
drop_path_rates = [
|
||||||
|
x.item() for x in torch.linspace(0, config.drop_path_rate[self.stage], config.depth[stage], device="cpu")
|
||||||
|
]
|
||||||
|
|
||||||
self.layers = nn.Sequential(
|
self.layers = nn.Sequential(
|
||||||
*[
|
*[
|
||||||
|
@@ -676,7 +676,7 @@ class Data2VecVisionEncoder(nn.Module):
|
|||||||
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
||||||
self.layer = nn.ModuleList(
|
self.layer = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Data2VecVisionLayer(
|
Data2VecVisionLayer(
|
||||||
|
@@ -790,7 +790,7 @@ class DonutSwinEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_layers = len(config.depths)
|
self.num_layers = len(config.depths)
|
||||||
self.config = config
|
self.config = config
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
DonutSwinStage(
|
DonutSwinStage(
|
||||||
|
@@ -486,7 +486,7 @@ class FocalNetStage(nn.Module):
|
|||||||
downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None
|
downsample = FocalNetPatchEmbeddings if (index < self.num_stages - 1) else None
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]
|
drop_path = dpr[sum(config.depths[:index]) : sum(config.depths[: index + 1])]
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
@@ -331,7 +331,7 @@ class GLPNEncoder(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
|
|
||||||
# patch embeddings
|
# patch embeddings
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
@@ -639,9 +639,9 @@ class HieraEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
total_depth = sum(config.depths)
|
total_depth = sum(config.depths)
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth)]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth, device="cpu")]
|
||||||
# query strides rule
|
# query strides rule
|
||||||
cumulative_depths = torch.tensor(config.depths).cumsum(0).tolist()
|
cumulative_depths = torch.tensor(config.depths, device="cpu").cumsum(0).tolist()
|
||||||
query_pool_layer = cumulative_depths[: config.num_query_pool]
|
query_pool_layer = cumulative_depths[: config.num_query_pool]
|
||||||
query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
|
query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
|
||||||
|
|
||||||
|
@@ -692,7 +692,7 @@ class MaskFormerSwinEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_layers = len(config.depths)
|
self.num_layers = len(config.depths)
|
||||||
self.config = config
|
self.config = config
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
MaskFormerSwinStage(
|
MaskFormerSwinStage(
|
||||||
|
@@ -246,7 +246,7 @@ class MgpstrEncoder(nn.Module):
|
|||||||
def __init__(self, config: MgpstrConfig):
|
def __init__(self, config: MgpstrConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
||||||
|
|
||||||
self.blocks = nn.Sequential(
|
self.blocks = nn.Sequential(
|
||||||
*[MgpstrLayer(config=config, drop_path=dpr[i]) for i in range(config.num_hidden_layers)]
|
*[MgpstrLayer(config=config, drop_path=dpr[i]) for i in range(config.num_hidden_layers)]
|
||||||
|
@@ -194,7 +194,7 @@ class PoolFormerEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
|
|
||||||
# patch embeddings
|
# patch embeddings
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
@@ -369,7 +369,7 @@ class PvtEncoder(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths)).tolist()
|
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
|
||||||
|
|
||||||
# patch embeddings
|
# patch embeddings
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
@@ -323,7 +323,7 @@ class PvtV2EncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
# Transformer block
|
# Transformer block
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths)).tolist()
|
drop_path_decays = torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu").tolist()
|
||||||
block_layers = []
|
block_layers = []
|
||||||
for block_idx in range(config.depths[layer_idx]):
|
for block_idx in range(config.depths[layer_idx]):
|
||||||
block_layers.append(
|
block_layers.append(
|
||||||
|
@@ -356,7 +356,9 @@ class SegformerEncoder(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
drop_path_decays = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
drop_path_decays = [
|
||||||
|
x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")
|
||||||
|
]
|
||||||
|
|
||||||
# patch embeddings
|
# patch embeddings
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
@@ -460,7 +460,7 @@ class SegGptEncoder(nn.Module):
|
|||||||
def __init__(self, config: SegGptConfig) -> None:
|
def __init__(self, config: SegGptConfig) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")]
|
||||||
self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([SegGptLayer(config, dpr[i]) for i in range(config.num_hidden_layers)])
|
||||||
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
@@ -823,7 +823,7 @@ class SwinEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_layers = len(config.depths)
|
self.num_layers = len(config.depths)
|
||||||
self.config = config
|
self.config = config
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
SwinStage(
|
SwinStage(
|
||||||
|
@@ -682,7 +682,7 @@ class Swin2SREncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_stages = len(config.depths)
|
self.num_stages = len(config.depths)
|
||||||
self.config = config
|
self.config = config
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
self.stages = nn.ModuleList(
|
self.stages = nn.ModuleList(
|
||||||
[
|
[
|
||||||
Swin2SRStage(
|
Swin2SRStage(
|
||||||
|
@@ -877,7 +877,7 @@ class Swinv2Encoder(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
if self.config.pretrained_window_sizes is not None:
|
if self.config.pretrained_window_sizes is not None:
|
||||||
pretrained_window_sizes = config.pretrained_window_sizes
|
pretrained_window_sizes = config.pretrained_window_sizes
|
||||||
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
|
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
|
@@ -295,7 +295,7 @@ class TimesformerLayer(nn.Module):
|
|||||||
attention_type = config.attention_type
|
attention_type = config.attention_type
|
||||||
|
|
||||||
drop_path_rates = [
|
drop_path_rates = [
|
||||||
x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
|
x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers, device="cpu")
|
||||||
] # stochastic depth decay rule
|
] # stochastic depth decay rule
|
||||||
drop_path_rate = drop_path_rates[layer_index]
|
drop_path_rate = drop_path_rates[layer_index]
|
||||||
|
|
||||||
|
@@ -535,7 +535,7 @@ class VitDetEncoder(nn.Module):
|
|||||||
depth = config.num_hidden_layers
|
depth = config.num_hidden_layers
|
||||||
|
|
||||||
# stochastic depth decay rule
|
# stochastic depth decay rule
|
||||||
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth)]
|
drop_path_rate = [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")]
|
||||||
|
|
||||||
layers = []
|
layers = []
|
||||||
for i in range(depth):
|
for i in range(depth):
|
||||||
|
@@ -4528,6 +4528,13 @@ class ModelTesterMixin:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_can_be_initialized_on_meta(self):
|
||||||
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
# If it does not raise here, the test passes
|
||||||
|
with torch.device("meta"):
|
||||||
|
_ = model_class(config)
|
||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
def test_can_load_with_device_context_manager(self):
|
def test_can_load_with_device_context_manager(self):
|
||||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
Loading…
Reference in New Issue
Block a user