Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 04:40:06 +06:00)
fix torch_dtype on awq (#38463)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 871901cb3d
commit 2e889c18e1
@@ -92,11 +92,11 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
-        elif torch_dtype == torch.bfloat16:
-            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
+            logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16:
-            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+        elif torch_dtype != torch.float16 and torch.cuda.is_available():
+            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
         return torch_dtype

     def _process_model_before_weight_loading(
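For readers without the surrounding file, here is a minimal standalone sketch of the dtype resolution after this fix. It assumes the hunk lives in the quantizer's torch_dtype-update hook; the free-function name `resolve_awq_torch_dtype` and the logger setup are illustrative, not part of the actual `AwqQuantizer` API:

import logging

import torch

logger = logging.getLogger(__name__)


def resolve_awq_torch_dtype(torch_dtype):
    # Sketch of the post-fix behavior, extracted as a free function
    # (in the real code this is a method on AwqQuantizer).
    if torch_dtype is None:
        # Default to float16, the dtype AWQ kernels are tuned for.
        torch_dtype = torch.float16
        logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
    elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
        # Only force-cast on CUDA; without the is_available() check,
        # bfloat16 was cast to float16 on every backend.
        logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
        torch_dtype = torch.float16
    elif torch_dtype != torch.float16 and torch.cuda.is_available():
        logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
    return torch_dtype

The practical effect, as the added `torch.cuda.is_available()` guards suggest, is that passing `torch_dtype=torch.bfloat16` on a CUDA-less host (for example a CPU or XPU setup) now keeps bfloat16 instead of being silently cast to float16.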