Mirror of https://github.com/huggingface/transformers.git
Enable different torch dtype in sub models (#34873)

* fix
* fix test
* add tests
* add more tests
* fix tests
* supposed to be a torch.dtype test
* handle BC and make fp32 default
parent 87089176d9
commit 84a6789145
src/transformers/configuration_utils.py

@@ -994,8 +994,11 @@ class PretrainedConfig(PushToHubMixin):
         converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
         string, which can then be stored in the json format.
         """
-        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
-            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+        if d.get("torch_dtype", None) is not None:
+            if isinstance(d["torch_dtype"], dict):
+                d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
+            elif not isinstance(d["torch_dtype"], str):
+                d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
         for value in d.values():
             if isinstance(value, dict):
                 self.dict_torch_dtype_to_str(value)
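With this change a composite config can carry one dtype per sub-config, and serialization flattens each `torch.dtype` to a plain string so the dict stays JSON-compatible. A minimal sketch of the new dict branch (the per-sub-config values here are hypothetical, not taken from a real checkpoint):

import torch
from transformers import PretrainedConfig

config = PretrainedConfig()
# hypothetical composite-style value: one dtype per sub-config
d = {"torch_dtype": {"text_config": torch.float32, "vision_config": torch.float16}}
config.dict_torch_dtype_to_str(d)  # mutates d in place
print(d["torch_dtype"])  # {'text_config': 'float32', 'vision_config': 'float16'}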
src/transformers/modeling_utils.py

@@ -1312,11 +1312,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                 "`PretrainedConfig`. To create a model from a pretrained model use "
                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
         # Save config and origin of the pretrained weights if given in model
-        if not getattr(config, "_attn_implementation_autoset", False):
-            config = self._autoset_attn_implementation(
-                config, torch_dtype=torch.get_default_dtype(), check_device_map=False
-            )
+        # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
+        dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
+        config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
         self.config = config
 
         # for initialization of the loss
@@ -1411,7 +1410,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
         # a warning is raised that dtype should be fp16. Since we never pass dtype from within
         # modeling code, we can try to infer it here same way as done in `from_pretrained`
-        torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
+        torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype)
+        if isinstance(torch_dtype, str):
+            torch_dtype = getattr(torch, torch_dtype)
+
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
         # override default dtype if needed
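The effect of these two hunks is that a model built from a config (rather than via `from_pretrained`) now picks its dtype up from `config.torch_dtype` instead of the process-wide default. A sketch of the `from_config` path under that assumption, using the tiny Mistral checkpoint referenced by the tests below:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
config.torch_dtype = torch.float16  # carried on the config; no torch_dtype kwarg is passed

# from_config routes through _from_config, which now falls back to config.torch_dtype
model = AutoModelForCausalLM.from_config(config)
print(model.dtype)  # torch.float16 rather than torch.get_default_dtype()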
@@ -4020,11 +4022,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
                        )
                elif hasattr(torch, torch_dtype):
                    torch_dtype = getattr(torch, torch_dtype)
-                else:
-                    raise ValueError(
-                        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
-                    )
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+            elif isinstance(torch_dtype, torch.dtype):
+                pass
+            elif isinstance(torch_dtype, dict):
+                for key, curr_dtype in torch_dtype.items():
+                    if hasattr(config, key):
+                        value = getattr(config, key)
+                        value.torch_dtype = curr_dtype
+                # main torch dtype for modules that aren't part of any sub-config
+                torch_dtype = torch_dtype.get("")
+                config.torch_dtype = torch_dtype
+                if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
+                    torch_dtype = getattr(torch, torch_dtype)
+                elif torch_dtype is None:
+                    torch_dtype = torch.float32
+            else:
+                raise ValueError(
+                    f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
+                    f"for each sub-config in composite configs, but received {torch_dtype}"
+                )
+
            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+        else:
+            # set fp32 as the default dtype for BC
+            default_dtype = str(torch.get_default_dtype()).split(".")[-1]
+            config.torch_dtype = default_dtype
+            for key in config.sub_configs.keys():
+                value = getattr(config, key)
+                value.torch_dtype = default_dtype
 
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
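In practice the new dict branch lets a caller pin one dtype per sub-model, with the empty-string key covering modules that belong to no sub-config; anything left unresolved falls back to fp32. A usage sketch mirroring the tests added further down (the tiny checkpoint is the one those tests use):

import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration",
    torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"},
)
print(model.language_model.dtype)  # torch.float32
print(model.vision_tower.dtype)    # torch.float16
# modules outside any sub-config (e.g. the multi-modal projector) use the "" entry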
src/transformers/models/chameleon/modeling_chameleon.py

@@ -967,62 +967,6 @@ class ChameleonVQVAEEncoder(nn.Module):
         return last_hidden_state
 
 
-CHAMELEON_VQ_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`ChameleonVQVAEConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
-    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
-    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
-    """,
-    CHAMELEON_VQ_START_DOCSTRING,
-)
-class ChameleonVQVAE(PreTrainedModel):
-    config_class = ChameleonVQVAEConfig
-    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-        elif isinstance(module, nn.GroupNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-
-    def __init__(self, config: ChameleonVQVAEConfig):
-        super().__init__(config)
-
-        self.encoder = ChameleonVQVAEEncoder(config)
-        self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
-        self.eval()  # Chameleon's VQ model is frozen
-
-    def encode(self, pixel_values: torch.LongTensor):
-        hidden_states = self.encoder(pixel_values)
-        hidden_states = self.quant_conv(hidden_states)
-        quant, emb_loss, indices = self.quantize(hidden_states)
-        return quant, emb_loss, indices
-
-
 class ChameleonImageVocabularyMapping:
     """
     A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ class ChameleonPreTrainedModel(PreTrainedModel):
             module.weight.data[module.padding_idx].zero_()
 
 
+CHAMELEON_VQ_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`ChameleonVQVAEConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
+    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
+    """,
+    CHAMELEON_VQ_START_DOCSTRING,
+)
+class ChameleonVQVAE(ChameleonPreTrainedModel):
+    config_class = ChameleonVQVAEConfig
+    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+        elif isinstance(module, nn.GroupNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def __init__(self, config: ChameleonVQVAEConfig):
+        super().__init__(config)
+
+        self.encoder = ChameleonVQVAEEncoder(config)
+        self.quantize = ChameleonVQVAEVectorQuantizer(config)
+        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+        self.eval()  # Chameleon's VQ model is frozen
+
+    def encode(self, pixel_values: torch.LongTensor):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quant, emb_loss, indices = self.quantize(hidden_states)
+        return quant, emb_loss, indices
+
+
 CHAMELEON_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
             [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.vqmodel = ChameleonVQVAE(config.vq_config)
+        self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
         self.gradient_checkpointing = False
 
         # Initialize weights and apply final processing
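Constructing the VQ model via `_from_config` instead of calling the class directly routes sub-model creation through the dtype handling above, so `vq_config.torch_dtype` is honored rather than whatever dtype the parent model happened to be built with. A self-contained sketch of the pattern with a toy model (ToyConfig/ToyModel are hypothetical stand-ins, not part of the library):

import torch
from transformers import PretrainedConfig, PreTrainedModel

class ToyConfig(PretrainedConfig):
    model_type = "toy"

class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.proj = torch.nn.Linear(4, 4)

cfg = ToyConfig(torch_dtype=torch.float16)
sub = ToyModel._from_config(cfg)  # sub-model weights are created in fp16
print(sub.proj.weight.dtype)      # torch.float16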
tests/models/qwen2_vl/test_modeling_qwen2_vl.py

@@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
     test_pruning = False
     test_head_masking = False
+    _is_composite = True
 
     def setUp(self):
         self.model_tester = Qwen2VLVisionText2TextModelTester(self)
tests/utils/test_modeling_utils.py

@@ -37,6 +37,7 @@ from transformers import (
     AutoModel,
     AutoModelForImageClassification,
     AutoModelForSequenceClassification,
+    LlavaForConditionalGeneration,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -300,6 +301,7 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
 TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
 TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
+TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
 
 LOG = logging.get_logger(__name__)
 
@@ -460,6 +462,59 @@ class ModelUtilsTest(TestCasePlus):
         with self.assertRaises(ValueError):
             model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
 
+    def test_model_from_config_torch_dtype_composite(self):
+        """
+        Test that `from_pretrained` works when `torch_dtype` is passed as a dict with one entry per sub-config in a composite config.
+        """
+        # should be able to set torch_dtype as a simple string and the model loads it correctly
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        self.assertEqual(model.language_model.dtype, torch.float16)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+
+        # should be able to set torch_dtype as a dict for each sub-config
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values as torch.dtype (not str)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values in configs directly and pass it to `from_pretrained`
+        config = copy.deepcopy(model.config)
+        config.text_config.torch_dtype = torch.float32
+        config.vision_config.torch_dtype = torch.bfloat16
+        config.torch_dtype = torch.float16
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+
+        # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
+        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+
+        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
+        with self.assertRaises(ValueError):
+            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
+            model = LlavaForConditionalGeneration.from_pretrained(
+                TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
+            )
+
     @require_torch
     def test_model_from_pretrained_meta_device(self):
         def is_on_meta(model_id, dtype):