ruff format

This commit is contained in:
Louis Brulé Naudet 2025-03-22 20:26:52 +01:00
parent b18353e343
commit 451df638e5

View File

@ -207,6 +207,7 @@ def is_dvclive_available():
def is_swanlab_available():
return importlib.util.find_spec("swanlab") is not None
def is_logfire_available() -> bool:
return importlib.util.find_spec("logfire") is not None
@ -2320,11 +2321,10 @@ class LogfireCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire).
"""
def __init__(self) -> None:
if not is_logfire_available():
raise RuntimeError(
"LogfireCallback requires `logfire` to be installed. Run `pip install logfire`."
)
raise RuntimeError("LogfireCallback requires `logfire` to be installed. Run `pip install logfire`.")
import logfire
@ -2332,16 +2332,18 @@ class LogfireCallback(TrainerCallback):
self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
self._initialized = False
self._logfire.configure(
token=self._logfire_token, console=False, inspect_arguments=False
)
self._logfire.configure(token=self._logfire_token, console=False, inspect_arguments=False)
def on_train_begin(self, args, state, control, model=None, **kwargs) -> None:
if self._logfire and state.is_local_process_zero:
def make_serializable(obj) -> object:
if hasattr(obj, '__dict__'):
return {k: make_serializable(v) for k, v in obj.__dict__.items()
if not k.startswith('_') and not callable(v)}
if hasattr(obj, "__dict__"):
return {
k: make_serializable(v)
for k, v in obj.__dict__.items()
if not k.startswith("_") and not callable(v)
}
elif isinstance(obj, list | tuple):
return [make_serializable(x) for x in obj]
elif isinstance(obj, dict):