Add accelerate support for ViT family (#20174)

* add `accelerate` support for `ViT` family
  - add `_no_split_modules`
  - manually cast to the right `dtype`: to change
* enable `float16` for `deit`
* fix `make fixup`
* add `slow` test for `fp16` inference
* another safety check
* Update src/transformers/models/deit/modeling_deit.py
parent 11b2e45ccc
commit f1e8c48c5e
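With this change the ViT family can be loaded straight into half precision and dispatched by accelerate, exactly as the new slow tests below do. A minimal sketch (requires `pip install accelerate`; the checkpoint is the one used in the new ViT test):

    import torch
    from transformers import ViTModel

    # fp16 weights, device placement handled by accelerate's "auto" device map
    model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")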
src/transformers/models/deit/modeling_deit.py

@@ -398,11 +398,16 @@ class DeiTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "deit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
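The `fp32` round-trip above is needed because PyTorch's `nn.init.trunc_normal_` has no half-precision CPU kernel, so initializing an fp16 model on CPU would raise. A small sketch of the workaround (the std of 0.02 is an illustrative value, not taken from any config):

    import torch
    import torch.nn as nn

    weight = torch.empty(4, 4, dtype=torch.float16)  # fp16 weight tensor on CPU
    # nn.init.trunc_normal_(weight, ...) would fail here; instead upcast,
    # initialize, and cast back, mirroring `_init_weights` above:
    weight = nn.init.trunc_normal_(weight.to(torch.float32), mean=0.0, std=0.02).to(torch.float16)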
@@ -511,6 +516,11 @@ class DeiTModel(DeiTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
         embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
 
         encoder_outputs = self.encoder(
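Because `forward` now casts `pixel_values` to the weights' dtype, fp32 tensors straight from the image processor can be fed to an fp16-dispatched model without a manual cast. A hedged sketch (random pixels stand in for a real preprocessed image, and a CUDA device is assumed):

    import torch
    from transformers import DeiTModel

    model = DeiTModel.from_pretrained(
        "facebook/deit-base-distilled-patch16-224", torch_dtype=torch.float16, device_map="auto"
    )
    pixel_values = torch.rand(1, 3, 224, 224).to("cuda")  # fp32, as an image processor would produce
    with torch.no_grad():
        outputs = model(pixel_values)  # the cast to fp16 happens inside forward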
src/transformers/models/vit/modeling_vit.py

@@ -68,14 +68,18 @@ class ViTEmbeddings(nn.Module):
         super().__init__()
 
         self.cls_token = nn.Parameter(
-            nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range)
+            nn.init.trunc_normal_(
+                torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range
+            )
         )
         self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
         self.position_embeddings = nn.Parameter(
             nn.init.trunc_normal_(
-                torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range
+                torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32),
+                mean=0.0,
+                std=config.initializer_range,
             )
         )
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -442,11 +446,16 @@ class ViTPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range)
+            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
+            # `trunc_normal_cpu` not implemented in `half` issues
+            module.weight.data = nn.init.trunc_normal_(
+                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
+            ).to(module.weight.dtype)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, nn.LayerNorm):
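An empty `_no_split_modules` tells accelerate that no block of the model has to live on a single device, so any module boundary is a legal split point when a device map is computed. Roughly what happens under the hood, as a sketch (the memory budget is made up for illustration):

    from accelerate import infer_auto_device_map
    from transformers import ViTModel

    model = ViTModel.from_pretrained("facebook/dino-vits8")
    # `no_split_module_classes` is fed from the model's `_no_split_modules`;
    # [] means the planner may split between any two submodules.
    device_map = infer_auto_device_map(model, max_memory={0: "200MB", "cpu": "2GB"}, no_split_module_classes=[])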
@@ -558,6 +567,11 @@ class ViTModel(ViTPreTrainedModel):
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
 
+        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+        if pixel_values.dtype != expected_dtype:
+            pixel_values = pixel_values.to(expected_dtype)
+
         embedding_output = self.embeddings(
             pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
         )
tests/models/deit/test_modeling_deit.py

@@ -21,7 +21,14 @@ import warnings
 
 from transformers import DeiTConfig
 from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property, is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -394,3 +401,23 @@ class DeiTModelIntegrationTest(unittest.TestCase):
         expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
+
+    @slow
+    @require_accelerate
+    @require_torch_gpu
+    def test_inference_fp16(self):
+        r"""
+        A small test to make sure that inference works in half precision without any problem.
+        """
+        model = DeiTModel.from_pretrained(
+            "facebook/deit-base-distilled-patch16-224", torch_dtype=torch.float16, device_map="auto"
+        )
+        feature_extractor = self.default_feature_extractor
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass to make sure inference works in fp16
+        with torch.no_grad():
+            _ = model(pixel_values)
tests/models/vit/test_modeling_vit.py

@@ -19,7 +19,14 @@ import inspect
 import unittest
 
 from transformers import ViTConfig
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import (
+    require_accelerate,
+    require_torch,
+    require_torch_gpu,
+    require_vision,
+    slow,
+    torch_device,
+)
 from transformers.utils import cached_property, is_torch_available, is_vision_available
 
 from ...test_configuration_common import ConfigTester
@@ -300,3 +307,21 @@ class ViTModelIntegrationTest(unittest.TestCase):
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    @require_accelerate
+    @require_torch_gpu
+    def test_inference_fp16(self):
+        r"""
+        A small test to make sure that inference works in half precision without any problem.
+        """
+        model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto")
+        feature_extractor = self.default_feature_extractor
+
+        image = prepare_img()
+        inputs = feature_extractor(images=image, return_tensors="pt")
+        pixel_values = inputs.pixel_values.to(torch_device)
+
+        # forward pass to make sure inference works in fp16
+        with torch.no_grad():
+            _ = model(pixel_values)
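Both new tests are gated behind `@slow`, `@require_accelerate`, and `@require_torch_gpu`, so they only run on a GPU machine with accelerate installed and slow tests enabled, e.g. `RUN_SLOW=1 pytest tests/models/deit/test_modeling_deit.py -k test_inference_fp16` (paths assume the usual transformers test layout).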