Fix import torchao.prototype.low_bit_optim since torchao v0.11 (#38174)

* Fix ModuleNotFoundError torchao.prototype.low_bit_optim since torchao v 0.11.0

* Fix space on blank line

* update torchao's AdamW4bit and AdamW8bit import for v0.11.0

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Chachura Baptiste 2025-05-16 18:02:33 +02:00 committed by GitHub
parent 0ba95564b7
commit a4389494c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1679,8 +1679,11 @@ class Trainer:
"You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
"Install it with `pip install --upgrade torch` it is available on pipy. Otherwise, you need to install torch nightly."
)
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
# https://github.com/pytorch/ao/pull/2159
from torchao.optim import AdamW4bit, AdamW8bit
else:
from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
optimizer_cls = AdamW4bit
elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT: