Enable multi-device for more models (#30409)

* feat: support for dinov2

* feat: support for depth_anything

* feat: support for efficientformer

* feat: support for bert (is this right?)

* update: embedding split

* remove: empty string

* feat: support for align

* fix: copies

* fix: QDQBertEmbeddings

* fix: more consistency issues

* revert: support for efficientformer

* feat: support for altclip

* feat: support for blip_text

* support for ChineseCLIP

* feat: support for depth anything

* feat: support for dpt

* feat: support for dpt

* feat: support for git

* feat: support for groupvit

* update: format

* fix: support for clip

* fix: consistency

* feat: support for pvt

* feat: support for vit_msn

* fix: consistency

* fix: other copies

* remove: device transfer

* revert: in-place add

* update: support for align

* update: support for bert

* update: support for Chinese CLIP

* revert: changes to efficientformer

* update: support for dpt

* update: support for efficientformer

* revert: changes to git

* revert: changes to groupvit

* revert: changes to roc_bert

* update: support for vit_msn

* revert: changes to dpt

* remove: extra space

* style: extra space
This commit is contained in:
Jacky Lee 2024-04-30 04:09:08 -07:00 committed by GitHub
parent c712d05aa8
commit 0ae789e043
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 12 additions and 0 deletions

View File

@@ -1203,6 +1203,7 @@ class AlignPreTrainedModel(PreTrainedModel):
)
class AlignTextModel(AlignPreTrainedModel):
config_class = AlignTextConfig
_no_split_modules = ["AlignTextEmbeddings"]
def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
super().__init__(config)

View File

@@ -1034,6 +1034,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
config_class = AltCLIPConfig
base_model_prefix = "altclip"
supports_gradient_checkpointing = True
_no_split_module = []
def _init_weights(self, module):
"""Initialize the weights"""

View File

@@ -962,6 +962,8 @@ class BertModel(BertPreTrainedModel):
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
"""
_no_split_modules = ["BertEmbeddings"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
self.config = config

View File

@@ -1113,6 +1113,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
"""
config_class = ChineseCLIPTextConfig
_no_split_modules = ["ChineseCLIPTextEmbeddings"]
def __init__(self, config, add_pooling_layer=True):
super().__init__(config)
@@ -1284,6 +1285,7 @@ class ChineseCLIPTextModel(ChineseCLIPPreTrainedModel):
class ChineseCLIPVisionModel(ChineseCLIPPreTrainedModel):
config_class = ChineseCLIPVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["ChineseCLIPVisionEmbeddings", "ChineseCLIPVisionAttention"]
def __init__(self, config: ChineseCLIPVisionConfig):
super().__init__(config)

View File

@@ -364,6 +364,8 @@ class DepthAnythingDepthEstimationHead(nn.Module):
DEPTH_ANYTHING_START_DOCSTRING,
)
class DepthAnythingForDepthEstimation(DepthAnythingPreTrainedModel):
_no_split_modules = ["DPTViTEmbeddings"]
def __init__(self, config):
super().__init__(config)

View File

@@ -481,6 +481,7 @@ class Dinov2PreTrainedModel(PreTrainedModel):
base_model_prefix = "dinov2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Dinov2SwiGLUFFN"]
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -555,6 +555,7 @@ class EfficientFormerModel(EfficientFormerPreTrainedModel):
def __init__(self, config: EfficientFormerConfig):
super().__init__(config)
self.config = config
_no_split_modules = ["EfficientFormerMeta4D"]
self.patch_embed = EfficientFormerConvStem(config, config.hidden_sizes[0])
self.encoder = EfficientFormerEncoder(config)

View File

@@ -462,6 +462,7 @@ class PvtPreTrainedModel(PreTrainedModel):
config_class = PvtConfig
base_model_prefix = "pvt"
main_input_name = "pixel_values"
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""

View File

@@ -421,6 +421,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["ViTMSNAttention"]
# todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211
# when creating pre-training scripts.