mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
Specifying torch dtype in Qwen2VLForConditionalGeneration (#33953)
* Specifying torch dtype * Reverting change & changing fallback _from_config() dtype
This commit is contained in:
parent
f8a260e2a4
commit
dda3f91d06
@ -1499,7 +1499,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
torch_dtype (`torch.dtype`, *optional*):
|
torch_dtype (`torch.dtype`, *optional*):
|
||||||
Override the default `torch.dtype` and load the model under this dtype.
|
Override the default `torch.dtype` and load the model under this dtype.
|
||||||
"""
|
"""
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
|
||||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||||
|
|
||||||
# override default dtype if needed
|
# override default dtype if needed
|
||||||
|
Loading…
Reference in New Issue
Block a user