Trainer + wandb quality of life logging tweaks (#6241)

* added `name` argument for wandb logging, also logging model config with trainer arguments

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* added tf, post-review changes

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Teven 2020-08-05 15:05:52 +02:00 committed by GitHub
parent 33966811bd
commit bd0eab351a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 2 deletions

View File

@ -383,7 +383,10 @@ class Trainer:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=self.args.to_sanitized_dict())
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
# keep track of model topology and gradients, unsupported on TPU
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(

View File

@ -215,7 +215,8 @@ class TFTrainer:
return self._setup_wandb()
logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"')
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name)
def prediction_loop(
self,

View File

@ -109,6 +109,8 @@ class TrainingArguments:
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
"""
output_dir: str = field(
@ -222,6 +224,10 @@ class TrainingArguments:
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
)
run_name: Optional[str] = field(
default=None, metadata={"help": "An optional descriptor for the run. Notably used for wandb logging."}
)
@property
def train_batch_size(self) -> int:
"""

View File

@ -95,6 +95,8 @@ class TFTrainingArguments(TrainingArguments):
at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on.
run_name (:obj:`str`, `optional`):
A descriptor for the run. Notably used for wandb logging.
"""
tpu_name: str = field(