Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[trainer] param count for deepspeed zero3 (#22193)

[trainer] param count for zero3

parent cf601b902f
commit 60d51ef512
src/transformers/trainer.py

@@ -97,6 +97,7 @@ from .trainer_pt_utils import (
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
     nested_concat,

@@ -1744,9 +1745,7 @@ class Trainer:
         logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
         logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
         logger.info(f"  Total optimization steps = {max_steps}")
-        logger.info(
-            f"  Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
-        )
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True)}")
 
         self.state.epoch = 0
         start_time = time.time()

src/transformers/trainer_pt_utils.py

@@ -35,6 +35,7 @@ from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler
 
+from .deepspeed import is_deepspeed_zero3_enabled
 from .tokenization_utils_base import BatchEncoding
 from .utils import is_sagemaker_mp_enabled, is_torch_tpu_available, is_training_run_on_sagemaker, logging
 
@@ -1032,6 +1033,23 @@ def save_state(self):
         self.state.save_to_json(path)
 
 
+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
 def get_parameter_names(model, forbidden_layer_types):
     """
     Returns the names of the model parameters that are not inside a forbidden layer.
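
To see why the helper needs the ZeRO-3 branch, here is a minimal, self-contained sketch (not part of the commit; the fake_zero3_partition helper is invented for illustration). Under DeepSpeed ZeRO-3 the live parameter tensors are partitioned, so a naive p.numel() sum can report 0, while DeepSpeed records the true element count on the ds_numel attribute that the new helper reads:

# Illustrative only: mimics the effect of ZeRO-3 partitioning on parameter
# metadata so the old vs. new counting behaviour can be compared without
# DeepSpeed installed.
import torch
from torch import nn


def fake_zero3_partition(model):
    """Pretend-partition: remember the full element count on ds_numel
    (DeepSpeed really does set this attribute), then empty the local data."""
    for p in model.parameters():
        p.ds_numel = p.numel()
        p.data = torch.empty(0)


model = nn.Linear(10, 10)
fake_zero3_partition(model)

naive = sum(p.numel() for p in model.parameters() if p.requires_grad)
zero3_aware = sum(p.ds_numel for p in model.parameters() if p.requires_grad)

print(naive)        # 0   -- what the old Trainer log line would report under ZeRO-3
print(zero3_aware)  # 110 -- what get_model_param_count(model, trainable_only=True) reports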