mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
Fix import torchao.prototype.low_bit_optim since torchao v0.11 (#38174)
* Fix ModuleNotFoundError for torchao.prototype.low_bit_optim since torchao v0.11.0 * Fix space on blank line * Update torchao's AdamW4bit and AdamW8bit import for v0.11.0 * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
0ba95564b7
commit
a4389494c7
@ -1679,8 +1679,11 @@ class Trainer:
|
||||
"You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
|
||||
"Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
|
||||
)
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
|
||||
# https://github.com/pytorch/ao/pull/2159
|
||||
from torchao.optim import AdamW4bit, AdamW8bit
|
||||
else:
|
||||
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
|
||||
if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
|
||||
optimizer_cls = AdamW4bit
|
||||
elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
|
||||
|
Loading…
Reference in New Issue
Block a user