Add accelerate support for ViT family (#20174)

* add `accelerate` support for `ViT` family

- add `_no_split_modules`
- manually cast to the right `dtype`: to change

* enable `float16` for `deit`

* fix `make fixup`

* add `slow` test for `fp16` inference

* another safety check

* Update src/transformers/models/deit/modeling_deit.py
This commit is contained in:
Younes Belkada 2022-11-15 11:06:01 +01:00 committed by GitHub
parent 11b2e45ccc
commit f1e8c48c5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 6 deletions

View File

@ -398,11 +398,16 @@ class DeiTPreTrainedModel(PreTrainedModel):
base_model_prefix = "deit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
@ -511,6 +516,11 @@ class DeiTModel(DeiTPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
encoder_outputs = self.encoder(

View File

@ -68,14 +68,18 @@ class ViTEmbeddings(nn.Module):
super().__init__()
self.cls_token = nn.Parameter(
nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
nn.init.trunc_normal_(
torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
)
)
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(
nn.init.trunc_normal_(
torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
mean=0.0,
std=config.initializer_range,
)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
@ -442,11 +446,16 @@ class ViTPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
@ -558,6 +567,11 @@ class ViTModel(ViTPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
if pixel_values.dtype != expected_dtype:
pixel_values = pixel_values.to(expected_dtype)
embedding_output = self.embeddings(
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
)

View File

@ -21,7 +21,14 @@ import warnings
from transformers import DeiTConfig
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -394,3 +401,23 @@ class DeiTModelIntegrationTest(unittest.TestCase):
expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
@slow
@require_accelerate
@require_torch_gpu
def test_inference_fp16(self):
r"""
A small test to make sure that inference work in half precision without any problem.
"""
model = DeiTModel.from_pretrained(
"facebook/deit-base-distilled-patch16-224", torch_dtype=torch.float16, device_map="auto"
)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass to make sure inference works in fp16
with torch.no_grad():
_ = model(pixel_values)

View File

@ -19,7 +19,14 @@ import inspect
import unittest
from transformers import ViTConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_torch,
require_torch_gpu,
require_vision,
slow,
torch_device,
)
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
@ -300,3 +307,21 @@ class ViTModelIntegrationTest(unittest.TestCase):
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
@slow
@require_accelerate
@require_torch_gpu
def test_inference_fp16(self):
r"""
A small test to make sure that inference work in half precision without any problem.
"""
model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass to make sure inference works in fp16
with torch.no_grad():
_ = model(pixel_values)