mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Specifying torch dtype in Qwen2VLForConditionalGeneration (#33953)
* Specifying torch dtype * Reverting change & changing fallback _from_config() dtype
This commit is contained in:
parent
f8a260e2a4
commit
dda3f91d06
@ -1499,7 +1499,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
torch_dtype (`torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype.
|
||||
"""
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
|
||||
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
|
||||
|
||||
# override default dtype if needed
|
||||
|
Loading…
Reference in New Issue
Block a user