Fix accelerate logger bug (#23650)

* fix logger bug

* Update tests/mixed_int8/test_mixed_int8.py

Co-authored-by: Zachary Mueller <muellerzr@gmail.com>

* import `PartialState`

---------

Co-authored-by: Zachary Mueller <muellerzr@gmail.com>
This commit is contained in:
Younes Belkada 2023-05-22 15:39:47 +02:00 committed by GitHub
parent 29294b0e68
commit 7bbdfd7b24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -29,6 +29,7 @@ from transformers import (
pipeline,
)
from transformers.testing_utils import (
is_accelerate_available,
is_torch_available,
require_accelerate,
require_bitsandbytes,
@ -40,6 +41,13 @@ from transformers.testing_utils import (
from transformers.utils.versions import importlib_metadata
if is_accelerate_available():
from accelerate import PartialState
from accelerate.logging import get_logger
logger = get_logger(__name__)
_ = PartialState()
if is_torch_available():
import torch
import torch.nn as nn