Specifying torch dtype in Qwen2VLForConditionalGeneration (#33953)

* Specifying torch dtype

* Reverting change & changing fallback _from_config() dtype
Hamza Tahboub authored on 2024-10-10 05:39:33 -07:00, committed by GitHub
parent f8a260e2a4
commit dda3f91d06


@@ -1499,7 +1499,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             torch_dtype (`torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model under this dtype.
         """
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
         # override default dtype if needed
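
The practical effect of the new fallback: when a model is instantiated from a config without an explicit `torch_dtype`, `_from_config()` now follows `torch.get_default_dtype()` instead of skipping the dtype override. Below is a minimal sketch of that behavior; the tiny GPT-2 config is an illustrative stand-in (the PR itself targets Qwen2VLForConditionalGeneration, which goes through the same `_from_config()` path).

```python
import torch
from transformers import AutoModelForCausalLM, GPT2Config

# Tiny config chosen only to keep the sketch fast to build; sizes are
# arbitrary and not part of this PR.
config = GPT2Config(n_embd=32, n_layer=2, n_head=2, vocab_size=128)

# An explicit torch_dtype still overrides the fallback.
model_bf16 = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
print(next(model_bf16.parameters()).dtype)  # torch.bfloat16

# Without an explicit dtype, _from_config now falls back to
# torch.get_default_dtype() (float32 unless the caller changed it via
# torch.set_default_dtype) rather than None.
model_default = AutoModelForCausalLM.from_config(config)
print(next(model_default.parameters()).dtype)  # torch.float32
```

Since `torch.get_default_dtype()` is `float32` unless the caller changes it, the default behavior is unchanged for most users; the difference shows up only when code relies on the process-wide default dtype.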