diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index 0f8246e8f98..4fa128a5f67 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -1203,6 +1203,7 @@ class AlignPreTrainedModel(PreTrainedModel):
 )
 class AlignTextModel(AlignPreTrainedModel):
     config_class = AlignTextConfig
+    _no_split_modules = ["AlignTextEmbeddings"]
 
     def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
         super().__init__(config)
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index ba8abb311a8..3e184085331 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -1034,6 +1034,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
     config_class = AltCLIPConfig
     base_model_prefix = "altclip"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index f7af0f1ef5a..9e2847b11b5 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -962,6 +962,8 @@ class BertModel(BertPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
+    _no_split_modules = ["BertEmbeddings"]
+
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
index 87a1baa217b..7d5c8f2fcc8 100644
--- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -1113,6 +1113,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
     """
 
     config_class = ChineseCLIPTextConfig
+    _no_split_modules = ["ChineseCLIPTextEmbeddings"]
 
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
@@ -1284,6 +1285,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
 class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
     config_class = ChineseCLIPVisionConfig
     main_input_name = "pixel_values"
+    _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
 
     def __init__(self, config: ChineseCLIPVisionConfig):
         super().__init__(config)
diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py
index bed91ac2a48..043bd0fac80 100644
--- a/src/transformers/models/depth_anything/modeling_depth_anything.py
+++ b/src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -364,6 +364,8 @@ class DepthAnythingDepthEstimationHead(nn.Module):
     DEPTH_ANYTHING_START_DOCSTRING,
 )
 class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
+    _no_split_modules = ["DPTViTEmbeddings"]
+
     def __init__(self, config):
         super().__init__(config)
 
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index c25022f6ec2..c90221f145d 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -481,6 +481,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
     base_model_prefix = "dinov2"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["Dinov2SwiGLUFFN"]
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/efficientformer/modeling_efficientformer.py b/src/transformers/models/efficientformer/modeling_efficientformer.py
index 70075cff55d..cc62e9cbd21 100644
--- a/src/transformers/models/efficientformer/modeling_efficientformer.py
+++ b/src/transformers/models/efficientformer/modeling_efficientformer.py
@@ -555,6 +555,8 @@ class EfficientFormerModel(EfficientFormerPreTrainedModel):
+    _no_split_modules = ["EfficientFormerMeta4D"]
+
     def __init__(self, config: EfficientFormerConfig):
         super().__init__(config)
         self.config = config
 
         self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
         self.encoder = EfficientFormerEncoder(config)
diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py
index b169af0cbd5..4574ca37876 100755
--- a/src/transformers/models/pvt/modeling_pvt.py
+++ b/src/transformers/models/pvt/modeling_pvt.py
@@ -462,6 +462,7 @@ class PvtPreTrainedModel(PreTrainedModel):
     config_class = PvtConfig
     base_model_prefix = "pvt"
     main_input_name = "pixel_values"
+    _no_split_modules = []
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index 9c2269a3ae5..424d657dc87 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -421,6 +421,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["ViTMSNAttention"]
 
     # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
     # when creating pre-training scripts.