mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-12 17:20:03 +06:00
ruff format
This commit is contained in:
parent
b18353e343
commit
451df638e5
@ -207,6 +207,7 @@ def is_dvclive_available():
|
|||||||
def is_swanlab_available():
|
def is_swanlab_available():
|
||||||
return importlib.util.find_spec("swanlab") is not None
|
return importlib.util.find_spec("swanlab") is not None
|
||||||
|
|
||||||
|
|
||||||
def is_logfire_available() -> bool:
|
def is_logfire_available() -> bool:
|
||||||
return importlib.util.find_spec("logfire") is not None
|
return importlib.util.find_spec("logfire") is not None
|
||||||
|
|
||||||
@ -2320,11 +2321,10 @@ class LogfireCallback(TrainerCallback):
|
|||||||
"""
|
"""
|
||||||
A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire).
|
A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
if not is_logfire_available():
|
if not is_logfire_available():
|
||||||
raise RuntimeError(
|
raise RuntimeError("LogfireCallback requires `logfire` to be installed. Run `pip install logfire`.")
|
||||||
"LogfireCallback requires `logfire` to be installed. Run `pip install logfire`."
|
|
||||||
)
|
|
||||||
|
|
||||||
import logfire
|
import logfire
|
||||||
|
|
||||||
@ -2332,16 +2332,18 @@ class LogfireCallback(TrainerCallback):
|
|||||||
self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
|
self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
self._logfire.configure(
|
self._logfire.configure(token=self._logfire_token, console=False, inspect_arguments=False)
|
||||||
token=self._logfire_token, console=False, inspect_arguments=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_train_begin(self, args, state, control, model=None, **kwargs) -> None:
|
def on_train_begin(self, args, state, control, model=None, **kwargs) -> None:
|
||||||
if self._logfire and state.is_local_process_zero:
|
if self._logfire and state.is_local_process_zero:
|
||||||
|
|
||||||
def make_serializable(obj) -> object:
|
def make_serializable(obj) -> object:
|
||||||
if hasattr(obj, '__dict__'):
|
if hasattr(obj, "__dict__"):
|
||||||
return {k: make_serializable(v) for k, v in obj.__dict__.items()
|
return {
|
||||||
if not k.startswith('_') and not callable(v)}
|
k: make_serializable(v)
|
||||||
|
for k, v in obj.__dict__.items()
|
||||||
|
if not k.startswith("_") and not callable(v)
|
||||||
|
}
|
||||||
elif isinstance(obj, list | tuple):
|
elif isinstance(obj, list | tuple):
|
||||||
return [make_serializable(x) for x in obj]
|
return [make_serializable(x) for x in obj]
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
|
Loading…
Reference in New Issue
Block a user