Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-24 06:48:58 +06:00
Add accelerate support for ViT family (#20174)
* add `accelerate` support for `ViT` family
  - add `_no_split_modules`
  - manually cast to the right `dtype`: to change
* enable `float16` for `deit`
* fix `make fixup`
* add `slow` test for `fp16` inference
* another safety check
* Update src/transformers/models/deit/modeling_deit.py
This commit is contained in:
parent 11b2e45ccc
commit f1e8c48c5e
src/transformers/models/deit/modeling_deit.py
@@ -398,11 +398,16 @@ class DeiTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "deit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []

     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
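The added comment points at a real CPU limitation: `nn.init.trunc_normal_` has no half-precision CPU kernel, so initializing an fp16 tensor in place raises `RuntimeError: "trunc_normal_cpu" not implemented for 'Half'`. A minimal sketch of the failure and of the upcast-then-downcast workaround the hunk uses (the tensor shape and `std=0.02` are placeholder values, not read from any config):

```python
import torch
import torch.nn as nn

weight = torch.empty(768, 768, dtype=torch.float16)  # placeholder shape

# This would raise on CPU: "trunc_normal_cpu" not implemented for 'Half'
# nn.init.trunc_normal_(weight, mean=0.0, std=0.02)

# Workaround from the diff: initialize in fp32, then cast back to the original dtype.
weight = nn.init.trunc_normal_(weight.to(torch.float32), mean=0.0, std=0.02).to(torch.float16)
print(weight.dtype)  # torch.float16
```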
@@ -511,6 +516,11 @@ class DeiTModel(DeiTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
         embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)

         encoder_outputs = self.encoder(
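The added block keeps the fp32 `pixel_values` produced by an image processor compatible with an fp16 model by casting them to the dtype of the patch-projection weights. A standalone sketch of the same check; `cast_pixel_values` is a hypothetical helper name, not part of the library:

```python
import torch

def cast_pixel_values(pixel_values: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # Hypothetical helper mirroring the safety check above: align the input
    # dtype with the model's weights (e.g. fp32 inputs fed to fp16 weights).
    expected_dtype = next(model.parameters()).dtype
    if pixel_values.dtype != expected_dtype:
        pixel_values = pixel_values.to(expected_dtype)
    return pixel_values
```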
src/transformers/models/vit/modeling_vit.py
@@ -68,14 +68,18 @@ class ViTEmbeddings(nn.Module):
         super().__init__()

         self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
+            nn.init.trunc_normal_(
+                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
+            )
         )
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(
             nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
+                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
+                mean=0.0,
+                std=config.initializer_range,
             )
         )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -442,11 +446,16 @@ class ViTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []

     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
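`_no_split_modules` is what opts a model into accelerate's big-model inference: `from_pretrained(..., device_map="auto")` consults it when spreading layers over the available devices, and an empty list means no submodule has to be kept whole on a single device. A sketch of the resulting usage, assuming `accelerate` is installed (the checkpoint is the one used in the tests below):

```python
import torch
from transformers import ViTModel

# device_map="auto" lets accelerate place layers across GPUs/CPU; it works
# here because ViTPreTrainedModel now defines _no_split_modules.
model = ViTModel.from_pretrained(
    "facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto"
)
print(model.dtype)  # torch.float16
```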
@@ -558,6 +567,11 @@ class ViTModel(ViTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
         embedding_output = self.embeddings(
             pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
         )
tests/models/deit/test_modeling_deit.py
@@ -21,7 +21,14 @@ import warnings

 from transformers import DeiTConfig
 from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property, is_torch_available, is_vision_available

 from ...test_configuration_common import ConfigTester
@@ -394,3 +401,23 @@ class DeiTModelIntegrationTest(unittest.TestCase):
         expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device)

         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+    @slow
+    @require_accelerate
+    @require_torch_gpu
+    def test_inference_fp16(self):
+        r"""
+        A small test to make sure that inference works in half precision without any problem.
+        """
+        model = DeiTModel.from_pretrained(
+            "facebook/deit-base-distilled-patch16-224", torch_dtype=torch.float16, device_map="auto"
+        )
+        feature_extractor = self.default_feature_extractor
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass to make sure inference works in fp16
+        with torch.no_grad():
+            _ = model(pixel_values)
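`require_accelerate` and `require_torch_gpu` come from `transformers.testing_utils` and skip the test when the dependency or hardware is missing. A hypothetical sketch of what such a skip decorator does; the actual implementations in `transformers.testing_utils` differ in detail, so this is illustrative only:

```python
import unittest

def require_accelerate(test_case):
    # Hypothetical re-implementation: run the test only if accelerate imports,
    # otherwise mark it as skipped with a reason.
    try:
        import accelerate  # noqa: F401
        return test_case
    except ImportError:
        return unittest.skip("test requires accelerate")(test_case)
```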
tests/models/vit/test_modeling_vit.py
@@ -19,7 +19,14 @@ import inspect
 import unittest

 from transformers import ViTConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property, is_torch_available, is_vision_available

 from ...test_configuration_common import ConfigTester
@@ -300,3 +307,21 @@ class ViTModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)

         self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    @require_accelerate
+    @require_torch_gpu
+    def test_inference_fp16(self):
+        r"""
+        A small test to make sure that inference works in half precision without any problem.
+        """
+        model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
+        feature_extractor = self.default_feature_extractor
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass to make sure inference works in fp16
+        with torch.no_grad():
+            _ = model(pixel_values)
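Outside the test suite, the same fp16 path looks like the sketch below. The checkpoint is the one from the test above; the image URL is the COCO sample commonly used in transformers examples and is illustrative only; a CUDA GPU and `accelerate` are assumed.

```python
import requests
import torch
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # illustrative sample image
image = Image.open(requests.get(url, stream=True).raw)

model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
feature_extractor = ViTFeatureExtractor.from_pretrained("facebook/dino-vits8")

# The feature extractor returns fp32 pixel_values; the new safety check in
# ViTModel.forward casts them to the model's fp16 weights internally.
inputs = feature_extractor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to("cuda")

with torch.no_grad():
    outputs = model(pixel_values)
print(outputs.last_hidden_state.dtype)  # torch.float16
```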