Improve notrainer examples (#17449)

* improve no-trainer examples

* Trigger CI

* adding comment to clarify tracker init on main process

* Trigger CI

* Trigger CI

* Trigger CI
This commit is contained in:
Sourab Mangrulkar 2022-05-28 00:06:31 +05:30 committed by GitHub
parent 7999ec125f
commit d156898f3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 310 additions and 118 deletions

View File

@ -163,7 +163,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--ignore_mismatched_sizes",
@ -192,8 +202,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
logger.info(accelerator.state)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@ -384,12 +397,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("image_classification_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("image_classification_no_trainer", experiment_config)
# Get the metric function
metric = load_metric("accuracy")
@ -506,10 +522,11 @@ def main():
accelerator.log(
{
"accuracy": eval_metric,
"train_loss": total_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -45,7 +45,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
@ -94,7 +93,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -206,7 +205,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -231,8 +240,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -451,7 +463,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
@ -488,12 +500,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("clm_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("clm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@ -577,15 +592,23 @@ def main():
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
perplexity = math.exp(torch.mean(losses))
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
logger.info(f"epoch {epoch}: perplexity: {perplexity}")
logger.info(f"epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}")
if args.with_tracking:
accelerator.log(
{"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
{
"perplexity": perplexity,
"eval_loss": eval_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -45,7 +45,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
@ -97,7 +96,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -213,7 +212,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -240,8 +249,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -492,7 +504,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# On TPU, the tie weights in our model have been disconnected, so we need to restore the ties.
if accelerator.distributed_type == DistributedType.TPU:
@ -532,12 +544,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("mlm_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("mlm_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@ -622,7 +637,8 @@ def main():
losses = torch.cat(losses)
losses = losses[: len(eval_dataset)]
try:
perplexity = math.exp(torch.mean(losses))
eval_loss = torch.mean(losses)
perplexity = math.exp(eval_loss)
except OverflowError:
perplexity = float("inf")
@ -630,7 +646,14 @@ def main():
if args.with_tracking:
accelerator.log(
{"perplexity": perplexity, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
{
"perplexity": perplexity,
"eval_loss": eval_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -43,7 +43,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForMultipleChoice,
AutoTokenizer,
@ -99,7 +98,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -194,7 +193,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -265,8 +274,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -447,7 +459,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
@ -484,12 +496,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("swag_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("swag_no_trainer", experiment_config)
# Metrics
metric = load_metric("accuracy")
@ -589,7 +604,13 @@ def main():
if args.with_tracking:
accelerator.log(
{"accuracy": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
{
"accuracy": eval_metric,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -41,7 +41,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForQuestionAnswering,
AutoTokenizer,
@ -135,7 +134,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -288,7 +287,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -321,8 +330,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -728,7 +740,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@ -761,12 +773,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("qa_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("qa_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@ -937,14 +952,14 @@ def main():
if args.with_tracking:
log = {
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
"train_loss": total_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
}
if args.do_predict:
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = predict_metric
accelerator.log(log)
accelerator.log(log, step=completed_steps)
if args.output_dir is not None:
accelerator.wait_for_everyone()

View File

@ -285,7 +285,17 @@ def parse_args():
"--with_tracking",
required=False,
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -306,8 +316,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
@ -482,11 +495,15 @@ def main():
# Instantiate metric
metric = load_metric("mean_iou")
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)
# Train!
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@ -615,10 +632,11 @@ def main():
"mean_iou": eval_metrics["mean_iou"],
"mean_accuracy": eval_metrics["mean_accuracy"],
"overall_accuracy": eval_metrics["overall_accuracy"],
"train_loss": total_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -43,7 +43,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@ -185,7 +184,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -287,7 +286,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -311,8 +320,11 @@ def parse_args():
def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
if args.source_prefix is None and args.model_name_or_path in [
"t5-small",
"t5-base",
@ -521,7 +533,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@ -554,12 +566,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("summarization_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("summarization_no_trainer", experiment_config)
# Metric
metric = load_metric("rouge")
@ -693,10 +708,10 @@ def main():
logger.info(result)
if args.with_tracking:
result["train_loss"] = total_loss
result["train_loss"] = total_loss.item() / len(train_dataloader)
result["epoch"] = epoch
result["step"] = completed_steps
accelerator.log(result)
accelerator.log(result, step=completed_steps)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
accelerator.wait_for_everyone()

View File

@ -33,7 +33,6 @@ from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository
from transformers import (
AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
@ -168,7 +167,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--ignore_mismatched_sizes",
@ -198,8 +207,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -403,7 +415,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@ -436,12 +448,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("glue_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("glue_no_trainer", experiment_config)
# Get the metric function
if args.task_name is not None:
@ -545,10 +560,11 @@ def main():
accelerator.log(
{
"accuracy" if args.task_name is not None else "glue": eval_metric,
"train_loss": total_loss,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:

View File

@ -40,7 +40,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForTokenClassification,
AutoTokenizer,
@ -114,7 +113,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -221,7 +220,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
parser.add_argument(
"--ignore_mismatched_sizes",
@ -251,8 +260,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@ -513,7 +525,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Use the device given by the `accelerator` object.
device = accelerator.device
@ -550,12 +562,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("ner_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("ner_no_trainer", experiment_config)
# Metrics
metric = load_metric("seqeval")
@ -698,7 +713,13 @@ def main():
accelerator.print(f"epoch {epoch}:", eval_metric)
if args.with_tracking:
accelerator.log(
{"seqeval": eval_metric, "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
{
"seqeval": eval_metric,
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1:
@ -731,7 +752,9 @@ def main():
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
json.dump({"eval_accuracy": eval_metric["accuracy"], "train_loss": float(loss.cpu().detach().numpy())}, f)
json.dump(
{"eval_accuracy": eval_metric["accuracy"], "train_loss": total_loss.item() / len(train_dataloader)}, f
)
if __name__ == "__main__":

View File

@ -41,7 +41,6 @@ from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
AdamW,
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
@ -180,7 +179,7 @@ def parse_args():
"--model_name_or_path",
type=str,
help="Path to pretrained model or model identifier from huggingface.co/models.",
required=True,
required=False,
)
parser.add_argument(
"--config_name",
@ -270,7 +269,17 @@ def parse_args():
parser.add_argument(
"--with_tracking",
action="store_true",
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
help="Whether to enable experiment trackers for logging.",
)
parser.add_argument(
"--report_to",
type=str,
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
args = parser.parse_args()
@ -297,8 +306,11 @@ def main():
args = parse_args()
# Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
# If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
# If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
# in the environment
accelerator = (
Accelerator(log_with=args.report_to, logging_dir=args.output_dir) if args.with_tracking else Accelerator()
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@ -502,7 +514,7 @@ def main():
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@ -535,12 +547,15 @@ def main():
else:
checkpointing_steps = None
# We need to initialize the trackers we use, and also store our configuration
# We need to initialize the trackers we use, and also store our configuration.
# We initialize the trackers only on main process because `accelerator.log`
# only logs on main process and we don't want empty logs/runs on other processes.
if args.with_tracking:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("translation_no_trainer", experiment_config)
if accelerator.is_main_process:
experiment_config = vars(args)
# TensorBoard cannot log Enums, need the raw value
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("translation_no_trainer", experiment_config)
metric = load_metric("sacrebleu")
@ -673,7 +688,13 @@ def main():
if args.with_tracking:
accelerator.log(
{"blue": eval_metric["score"], "train_loss": total_loss, "epoch": epoch, "step": completed_steps},
{
"blue": eval_metric["score"],
"train_loss": total_loss.item() / len(train_dataloader),
"epoch": epoch,
"step": completed_steps,
},
step=completed_steps,
)
if args.push_to_hub and epoch < args.num_train_epochs - 1: