diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 432b6134400..33669d3f518 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -207,6 +207,7 @@ def is_dvclive_available(): def is_swanlab_available(): return importlib.util.find_spec("swanlab") is not None + def is_logfire_available() -> bool: return importlib.util.find_spec("logfire") is not None @@ -2320,11 +2321,10 @@ class LogfireCallback(TrainerCallback): """ A [`TrainerCallback`] that sends the logs to [Logfire](https://pydantic.dev/logfire). """ + def __init__(self) -> None: if not is_logfire_available(): - raise RuntimeError( - "LogfireCallback requires `logfire` to be installed. Run `pip install logfire`." - ) + raise RuntimeError("LogfireCallback requires `logfire` to be installed. Run `pip install logfire`.") import logfire @@ -2332,16 +2332,18 @@ class LogfireCallback(TrainerCallback): self._logfire_token = os.getenv("LOGFIRE_TOKEN", None) self._initialized = False - self._logfire.configure( - token=self._logfire_token, console=False, inspect_arguments=False - ) + self._logfire.configure(token=self._logfire_token, console=False, inspect_arguments=False) def on_train_begin(self, args, state, control, model=None, **kwargs) -> None: if self._logfire and state.is_local_process_zero: + def make_serializable(obj) -> object: - if hasattr(obj, '__dict__'): - return {k: make_serializable(v) for k, v in obj.__dict__.items() - if not k.startswith('_') and not callable(v)} + if hasattr(obj, "__dict__"): + return { + k: make_serializable(v) + for k, v in obj.__dict__.items() + if not k.startswith("_") and not callable(v) + } elif isinstance(obj, list | tuple): return [make_serializable(x) for x in obj] elif isinstance(obj, dict):