From 0ae789e04330e15a90e34cd723c851a8ab8d7ec5 Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Tue, 30 Apr 2024 04:09:08 -0700
Subject: [PATCH] Enable multi-device for more models (#30409)

* feat: support for dinov2
* feat: support for depth_anything
* feat: support for efficientformer
* feat: support for bert (is this right?)
* update: embedding split
* remove: empty string
* feat: support for align
* fix: copies
* fix: QAQBertEmbeddings
* fix: more consistency issues
* revert: support for effientformer
* feat: support for altclip
* feat: support for blip_text
* support for ChineseCLIP
* feat: support for depth anything
* feat: support for dpt
* feat: support for dpt
* feat: support for git
* feat: support for groupvit
* update: format
* fix: support for clip
* fix: consistency
* feat: support for pvt
* feat: support for vit_msn
* fix: consistency
* fix: other copies
* remove: device transfer
* revert: in-place add
* update: support for align
* update: support for bert
* update: support for Chinese CLIP
* revert: changes to efficientformer
* update: support for dpt
* update: support for efficientformer
* revert: changes to git
* revert: changes to groupvit
* revert: changes to roc_bert
* update: support for vit_msn
* revert: changes to dpt
* remove: extra space
* style: extra space
---
 src/transformers/models/align/modeling_align.py | 1 +
 src/transformers/models/altclip/modeling_altclip.py | 1 +
 src/transformers/models/bert/modeling_bert.py | 2 ++
 src/transformers/models/chinese_clip/modeling_chinese_clip.py | 2 ++
 .../models/depth_anything/modeling_depth_anything.py | 2 ++
 src/transformers/models/dinov2/modeling_dinov2.py | 1 +
 .../models/efficientformer/modeling_efficientformer.py | 1 +
 src/transformers/models/pvt/modeling_pvt.py | 1 +
 src/transformers/models/vit_msn/modeling_vit_msn.py | 1 +
 9 files changed, 12 insertions(+)

diff --git a/src/transformers/models/align/modeling_align.py b/src/transformers/models/align/modeling_align.py
index 0f8246e8f98..4fa128a5f67 100644
--- a/src/transformers/models/align/modeling_align.py
+++ b/src/transformers/models/align/modeling_align.py
@@ -1203,6 +1203,7 @@ class AlignPreTrainedModel(PreTrainedModel):
 )
 class AlignTextModel(AlignPreTrainedModel):
     config_class = AlignTextConfig
+    _no_split_modules = ["AlignTextEmbeddings"]
 
     def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
         super().__init__(config)
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index ba8abb311a8..3e184085331 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -1034,6 +1034,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
     config_class = AltCLIPConfig
     base_model_prefix = "altclip"
     supports_gradient_checkpointing = True
+    _no_split_module = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index f7af0f1ef5a..9e2847b11b5 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -962,6 +962,8 @@ class BertModel(BertPreTrainedModel):
     `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
     """
 
+    _no_split_modules = ["BertEmbeddings"]
+
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
diff --git a/src/transformers/models/chinese_clip/modeling_chinese_clip.py b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
index 87a1baa217b..7d5c8f2fcc8 100644
--- a/src/transformers/models/chinese_clip/modeling_chinese_clip.py
+++ b/src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -1113,6 +1113,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
     """
 
     config_class = ChineseCLIPTextConfig
+    _no_split_modules = ["ChineseCLIPTextEmbeddings"]
 
     def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
@@ -1284,6 +1285,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
 class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
     config_class = ChineseCLIPVisionConfig
     main_input_name = "pixel_values"
+    _no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
 
     def __init__(self, config: ChineseCLIPVisionConfig):
         super().__init__(config)
diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py
index bed91ac2a48..043bd0fac80 100644
--- a/src/transformers/models/depth_anything/modeling_depth_anything.py
+++ b/src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -364,6 +364,8 @@ class DepthAnythingDepthEstimationHead(nn.Module):
     DEPTH_ANYTHING_START_DOCSTRING,
 )
 class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
+    _no_split_modules = ["DPTViTEmbeddings"]
+
     def __init__(self, config):
         super().__init__(config)
 
diff --git a/src/transformers/models/dinov2/modeling_dinov2.py b/src/transformers/models/dinov2/modeling_dinov2.py
index c25022f6ec2..c90221f145d 100644
--- a/src/transformers/models/dinov2/modeling_dinov2.py
+++ b/src/transformers/models/dinov2/modeling_dinov2.py
@@ -481,6 +481,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
     base_model_prefix = "dinov2"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["Dinov2SwiGLUFFN"]
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/efficientformer/modeling_efficientformer.py b/src/transformers/models/efficientformer/modeling_efficientformer.py
index 70075cff55d..cc62e9cbd21 100644
--- a/src/transformers/models/efficientformer/modeling_efficientformer.py
+++ b/src/transformers/models/efficientformer/modeling_efficientformer.py
@@ -555,6 +555,7 @@ class EfficientFormerModel(EfficientFormerPreTrainedModel):
     def __init__(self, config: EfficientFormerConfig):
         super().__init__(config)
         self.config = config
+        _no_split_modules = ["EfficientFormerMeta4D"]
 
         self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
         self.encoder = EfficientFormerEncoder(config)
diff --git a/src/transformers/models/pvt/modeling_pvt.py b/src/transformers/models/pvt/modeling_pvt.py
index b169af0cbd5..4574ca37876 100755
--- a/src/transformers/models/pvt/modeling_pvt.py
+++ b/src/transformers/models/pvt/modeling_pvt.py
@@ -462,6 +462,7 @@ class PvtPreTrainedModel(PreTrainedModel):
    config_class = PvtConfig
     base_model_prefix = "pvt"
     main_input_name = "pixel_values"
+    _no_split_modules = []
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index 9c2269a3ae5..424d657dc87 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -421,6 +421,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["ViTMSNAttention"]
 
     # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
     # when creating pre-training scripts.
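
Usage note: a minimal sketch of how the `_no_split_modules` entries added above are consumed. They tell Accelerate's automatic device placement (`device_map="auto"`) which module classes must be kept whole on a single device when a checkpoint is sharded across GPUs/CPU, so tensors that interact inside such a module (embeddings, attention blocks, SwiGLU FFNs) never straddle a device boundary. The checkpoint name below is only an example; this assumes `accelerate` is installed and more than one device is visible.

    from transformers import AutoModel

    # device_map="auto" lets Accelerate spread submodules across the available
    # devices; classes listed in _no_split_modules (e.g. BertEmbeddings) are
    # placed on one device instead of being split mid-module.
    model = AutoModel.from_pretrained("google-bert/bert-base-uncased", device_map="auto")

    # Inspect which device each submodule was assigned to.
    print(model.hf_device_map)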