Enable different torch dtype in sub models (#34873)

* fix

* fix test

* add tests

* add more tests

* fix tests

* supposed to be a torch.dtype test

* handle BC and make fp32 default
Raushan Turganbay 2025-01-13 13:42:08 +01:00 committed by GitHub
parent 87089176d9
commit 84a6789145
5 changed files with 155 additions and 68 deletions
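For reference, a minimal usage sketch of what this commit enables (the checkpoint id and expected dtypes mirror the new test added below; illustration only, not an official example):

from transformers import LlavaForConditionalGeneration

# Dict keys match sub-config names; the "" entry covers modules that do not
# belong to any sub-config (e.g. the multi-modal projector).
model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration",
    torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"},
)
print(model.language_model.dtype)                         # torch.float32
print(model.vision_tower.dtype)                           # torch.float16
print(model.multi_modal_projector.linear_1.weight.dtype)  # torch.bfloat16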


@@ -994,8 +994,11 @@ class PretrainedConfig(PushToHubMixin):
converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
if d.get("torch_dtype", None) is not None:
if isinstance(d["torch_dtype"], dict):
d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
elif not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
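As a standalone illustration (not library code), the dict branch above converts each sub-config dtype to its short string name so the config can be written to JSON:

import torch

d = {"torch_dtype": {"text_config": torch.float32, "vision_config": torch.float16}}
if isinstance(d["torch_dtype"], dict):
    # same conversion as above: torch.float16 -> "float16"
    d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
print(d)  # {'torch_dtype': {'text_config': 'float32', 'vision_config': 'float16'}}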


@@ -1312,11 +1312,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
if not getattr(config, "_attn_implementation_autoset", False):
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
# config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
self.config = config
# for initialization of the loss
@@ -1411,7 +1410,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here the same way as done in `from_pretrained`
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype)
if isinstance(torch_dtype, str):
torch_dtype = getattr(torch, torch_dtype)
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
# override default dtype if needed
@@ -4020,11 +4022,37 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
elif hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
else:
raise ValueError(
f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
)
for sub_config_key in config.sub_configs.keys():
sub_config = getattr(config, sub_config_key)
sub_config.torch_dtype = torch_dtype
elif isinstance(torch_dtype, torch.dtype):
pass
elif isinstance(torch_dtype, dict):
for key, curr_dtype in torch_dtype.items():
if hasattr(config, key):
value = getattr(config, key)
value.torch_dtype = curr_dtype
# main torch dtype for modules that aren't part of any sub-config
torch_dtype = torch_dtype.get("")
config.torch_dtype = torch_dtype
if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
torch_dtype = getattr(torch, torch_dtype)
elif torch_dtype is None:
torch_dtype = torch.float32
else:
raise ValueError(
f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
f"for each sub-config in composite configs, but received {torch_dtype}"
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
else:
# set fp32 as the default dtype for BC
default_dtype = str(torch.get_default_dtype()).split(".")[-1]
config.torch_dtype = default_dtype
for key in config.sub_configs.keys():
value = getattr(config, key)
value.torch_dtype = default_dtype
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
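A self-contained sketch of the dict branch above (illustration only; the config stand-ins are hypothetical): sub-config entries are assigned first, the "" entry becomes the main dtype for the remaining modules, and fp32 is the fallback when it is missing.

import types
import torch

config = types.SimpleNamespace(text_config=types.SimpleNamespace(), vision_config=types.SimpleNamespace())
torch_dtype = {"text_config": "float32", "vision_config": "float16"}  # no "" entry

for key, curr_dtype in torch_dtype.items():
    if hasattr(config, key):
        getattr(config, key).torch_dtype = curr_dtype
main_dtype = torch_dtype.get("")  # None: nothing specified for the remaining modules
config.torch_dtype = main_dtype
if isinstance(main_dtype, str) and hasattr(torch, main_dtype):
    main_dtype = getattr(torch, main_dtype)
elif main_dtype is None:
    main_dtype = torch.float32  # fp32 fallback, as in the change above

print(config.text_config.torch_dtype, config.vision_config.torch_dtype, main_dtype)
# float32 float16 torch.float32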


@@ -967,62 +967,6 @@ class ChameleonVQVAEEncoder(nn.Module):
return last_hidden_state
CHAMELEON_VQ_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`ChameleonVQVAEConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"""The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
""",
CHAMELEON_VQ_START_DOCSTRING,
)
class ChameleonVQVAE(PreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
return quant, emb_loss, indices
class ChameleonImageVocabularyMapping:
"""
A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ class ChameleonPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
CHAMELEON_VQ_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`ChameleonVQVAEConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"""The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
[ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
""",
CHAMELEON_VQ_START_DOCSTRING,
)
class ChameleonVQVAE(ChameleonPreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)
self.encoder = ChameleonVQVAEEncoder(config)
self.quantize = ChameleonVQVAEVectorQuantizer(config)
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
self.eval() # Chameleon's VQ model is frozen
def encode(self, pixel_values: torch.LongTensor):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
return quant, emb_loss, indices
CHAMELEON_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ class ChameleonModel(ChameleonPreTrainedModel):
[decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.vqmodel = ChameleonVQVAE(config.vq_config)
self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
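Switching from the plain constructor to `_from_config` lets the VQ sub-model pick up the dtype recorded on its own sub-config, via the `_from_config` change above. A hedged sketch of the expected effect (inferred from the diff, not a test from this commit):

import torch
from transformers.models.chameleon.configuration_chameleon import ChameleonVQVAEConfig
from transformers.models.chameleon.modeling_chameleon import ChameleonVQVAE

vq_config = ChameleonVQVAEConfig()
vq_config.torch_dtype = torch.float16
# _from_config now falls back to config.torch_dtype instead of torch.get_default_dtype()
vq_model = ChameleonVQVAE._from_config(vq_config)
print(vq_model.dtype)  # expected: torch.float16 (assumption based on the diff)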


@@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
test_pruning = False
test_head_masking = False
_is_composite = True
def setUp(self):
self.model_tester = Qwen2VLVisionText2TextModelTester(self)


@@ -37,6 +37,7 @@ from transformers import (
AutoModel,
AutoModelForImageClassification,
AutoModelForSequenceClassification,
LlavaForConditionalGeneration,
OwlViTForObjectDetection,
PretrainedConfig,
is_torch_available,
@@ -300,6 +301,7 @@ TINY_T5 = "patrickvonplaten/t5-tiny-random"
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"
LOG = logging.get_logger(__name__)
@@ -460,6 +462,59 @@ class ModelUtilsTest(TestCasePlus):
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")
def test_model_from_config_torch_dtype_composite(self):
"""
Test that `from_pretrained` works with `torch_dtype` passed as a dict, with one entry per sub-config in a composite config
"""
# should be able to set torch_dtype as a simple string and the model loads it correctly
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32)
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
self.assertEqual(model.language_model.dtype, torch.float16)
self.assertEqual(model.vision_tower.dtype, torch.float16)
# should be able to set torch_dtype as a dict for each sub-config
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values as torch.dtype (not str)
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
)
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
# should be able to set the values in configs directly and pass it to `from_pretrained`
config = copy.deepcopy(model.config)
config.text_config.torch_dtype = torch.float32
config.vision_config.torch_dtype = torch.bfloat16
config.torch_dtype = torch.float16
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
# but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
with self.assertRaises(ValueError):
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
model = LlavaForConditionalGeneration.from_pretrained(
TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
)
@require_torch
def test_model_from_pretrained_meta_device(self):
def is_on_meta(model_id, dtype):