ruff format

This commit is contained in:
Louis Brulé Naudet 2025-03-22 20:26:52 +01:00
parent b18353e343
commit 451df638e5

View File

@ -207,6 +207,7 @@ def is_dvclive_available():
def is_swanlab_available(): def is_swanlab_available():
return importlib.util.find_spec("swanlab") is not None return importlib.util.find_spec("swanlab") is not None
def is_logfire_available() -> bool: def is_logfire_available() -> bool:
return importlib.util.find_spec("logfire") is not None return importlib.util.find_spec("logfire") is not None
@ -2320,11 +2321,10 @@ class LogfireCallback(TrainerCallback):
""" """
A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire). A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire).
""" """
def __init__(self) -> None: def __init__(self) -> None:
if not is_logfire_available(): if not is_logfire_available():
raise RuntimeError( raise RuntimeError("LogfireCallback requires `logfire` to be installed. Run `pip install logfire`.")
"LogfireCallback requires `logfire` to be installed. Run `pip install logfire`."
)
import logfire import logfire
@ -2332,16 +2332,18 @@ class LogfireCallback(TrainerCallback):
self._logfire_token = os.getenv("LOGFIRE_TOKEN", None) self._logfire_token = os.getenv("LOGFIRE_TOKEN", None)
self._initialized = False self._initialized = False
self._logfire.configure( self._logfire.configure(token=self._logfire_token, console=False, inspect_arguments=False)
token=self._logfire_token, console=False, inspect_arguments=False
)
def on_train_begin(self, args, state, control, model=None, **kwargs) -> None: def on_train_begin(self, args, state, control, model=None, **kwargs) -> None:
if self._logfire and state.is_local_process_zero: if self._logfire and state.is_local_process_zero:
def make_serializable(obj) -> object: def make_serializable(obj) -> object:
if hasattr(obj, '__dict__'): if hasattr(obj, "__dict__"):
return {k: make_serializable(v) for k, v in obj.__dict__.items() return {
if not k.startswith('_') and not callable(v)} k: make_serializable(v)
for k, v in obj.__dict__.items()
if not k.startswith("_") and not callable(v)
}
elif isinstance(obj, list | tuple): elif isinstance(obj, list | tuple):
return [make_serializable(x) for x in obj] return [make_serializable(x) for x in obj]
elif isinstance(obj, dict): elif isinstance(obj, dict):