[trainer] param count for deepspeed zero3 (#22193)

[trainer] param count for zero3
Stas Bekman, 2023-03-17 11:02:55 -07:00, committed by GitHub
parent cf601b902f
commit 60d51ef512
2 changed files with 20 additions and 3 deletions

src/transformers/trainer.py

@@ -97,6 +97,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,
@@ -1744,9 +1745,7 @@ class Trainer:
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
-        logger.info(
-            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")
 
         self.state.epoch = 0
         start_time = time.time()

src/transformers/trainer_pt_utils.py

@@ -35,6 +35,7 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
@@ -1032,6 +1033,23 @@ def save_state(self):
     self.state.save_to_json(path)
 
 
+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
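For context (not part of the commit): under ZeRO-3, DeepSpeed partitions each parameter across ranks, so a partitioned parameter's numel() reports only the local shard (often 0), while the full element count is recorded on the DeepSpeed-added attribute ds_numel. That is why get_model_param_count branches on is_deepspeed_zero3_enabled(). A minimal usage sketch, runnable without DeepSpeed:

import torch.nn as nn

from transformers.trainer_pt_utils import get_model_param_count

model = nn.Linear(10, 10)  # 10*10 weights + 10 biases = 110 params

# Without ZeRO-3 the helper falls through to plain p.numel().
print(get_model_param_count(model))                       # 110
print(get_model_param_count(model, trainable_only=True))  # 110

# Freeze the bias; trainable_only now skips it.
model.bias.requires_grad_(False)
print(get_model_param_count(model, trainable_only=True))  # 100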