diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 11acf947bb6..b741347d619 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -398,11 +398,16 @@ class DeiTPreTrainedModel(PreTrainedModel): base_model_prefix = "deit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): @@ -511,6 +516,11 @@ class DeiTModel(DeiTPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) encoder_outputs = self.encoder( diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index c7bf682cc87..2bf59866efb 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -68,14 +68,18 @@ class ViTEmbeddings(nn.Module): super().__init__() self.cls_token = nn.Parameter( - nn.init.trunc_normal_(torch.zeros(1, 1, config.hidden_size), mean=0.0, std=config.initializer_range) + nn.init.trunc_normal_( + torch.zeros(1, 1, config.hidden_size, dtype=torch.float32), mean=0.0, std=config.initializer_range + ) ) self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None self.patch_embeddings = ViTPatchEmbeddings(config) num_patches = self.patch_embeddings.num_patches self.position_embeddings = nn.Parameter( nn.init.trunc_normal_( - torch.zeros(1, num_patches + 1, config.hidden_size), mean=0.0, std=config.initializer_range + torch.zeros(1, num_patches + 1, config.hidden_size, dtype=torch.float32), + mean=0.0, + std=config.initializer_range, ) ) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -442,11 +446,16 @@ class ViTPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True + _no_split_modules = [] def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=self.config.initializer_range) + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.LayerNorm): @@ -558,6 +567,11 @@ class ViTModel(ViTPreTrainedModel): # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + embedding_output = self.embeddings( pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding ) diff --git a/tests/models/deit/test_modeling_deit.py b/tests/models/deit/test_modeling_deit.py index 82b7f286925..19858cb5b7f 100644 --- a/tests/models/deit/test_modeling_deit.py +++ b/tests/models/deit/test_modeling_deit.py @@ -21,7 +21,14 @@ import warnings from transformers import DeiTConfig from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_accelerate, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -394,3 +401,23 @@ class DeiTModelIntegrationTest(unittest.TestCase): expected_slice = torch.tensor([-1.0266, 0.1912, -1.2861]).to(torch_device) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) + + @slow + @require_accelerate + @require_torch_gpu + def test_inference_fp16(self): + r""" + A small test to make sure that inference work in half precision without any problem. + """ + model = DeiTModel.from_pretrained( + "facebook/deit-base-distilled-patch16-224", torch_dtype=torch.float16, device_map="auto" + ) + feature_extractor = self.default_feature_extractor + + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass to make sure inference works in fp16 + with torch.no_grad(): + _ = model(pixel_values) diff --git a/tests/models/vit/test_modeling_vit.py b/tests/models/vit/test_modeling_vit.py index 5f856436f3c..52e09aab774 100644 --- a/tests/models/vit/test_modeling_vit.py +++ b/tests/models/vit/test_modeling_vit.py @@ -19,7 +19,14 @@ import inspect import unittest from transformers import ViTConfig -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + require_accelerate, + require_torch, + require_torch_gpu, + require_vision, + slow, + torch_device, +) from transformers.utils import cached_property, is_torch_available, is_vision_available from ...test_configuration_common import ConfigTester @@ -300,3 +307,21 @@ class ViTModelIntegrationTest(unittest.TestCase): ).to(torch_device) self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4)) + + @slow + @require_accelerate + @require_torch_gpu + def test_inference_fp16(self): + r""" + A small test to make sure that inference work in half precision without any problem. + """ + model = ViTModel.from_pretrained("facebook/dino-vits8", torch_dtype=torch.float16, device_map="auto") + feature_extractor = self.default_feature_extractor + + image = prepare_img() + inputs = feature_extractor(images=image, return_tensors="pt") + pixel_values = inputs.pixel_values.to(torch_device) + + # forward pass to make sure inference works in fp16 + with torch.no_grad(): + _ = model(pixel_values)