SageMaker Model Parallel TensorBoard writing fix (#10403)

* Added tb fix

* Removed local rank condition

* Updated reference to args
Mansi Mane 2021-02-26 05:04:55 -08:00 committed by GitHub
parent 83d2d55c94
commit 7fc686efb1

@@ -71,11 +71,21 @@ if is_smdistributed_available():
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
+        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
         if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
             raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be :obj:`True` for one process).
+        """
+        if self.is_model_parallel_enabled:
+            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
+        else:
+            return super().is_world_process_zero()
+
     def _get_train_sampler(self):
         if self.is_model_parallel_enabled:
             if self.args.group_by_length:
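
The override matters because the base Trainer gates TensorBoard writing (along with other logging and checkpointing) on is_world_process_zero(); under SageMaker model parallelism, more than one process could otherwise pass that check and emit duplicate or conflicting event files. Below is a minimal, hypothetical sketch of the same rank gating applied directly to a SummaryWriter, not code from this commit. It assumes a SageMaker model-parallel job where smdistributed.modelparallel.torch is importable and smp.init() is safe to call here; the log directory and loop values are placeholders.

# Minimal sketch (not from the commit): the rank gating from the override above,
# applied to a bare SummaryWriter. Assumes an smdistributed model-parallel job;
# the log directory and the dummy training loop are placeholders.
import smdistributed.modelparallel.torch as smp
from torch.utils.tensorboard import SummaryWriter

smp.init()


def is_world_process_zero() -> bool:
    # Mirrors the override: only the process that is rank 0 on every axis
    # (global, node-local, model-parallel, data-parallel) writes logs.
    return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0


# Only the single "world zero" process creates event files, so TensorBoard
# does not receive duplicate or interleaved scalars from every rank.
writer = SummaryWriter(log_dir="runs/smp-demo") if is_world_process_zero() else None

for step, loss in enumerate([0.9, 0.7, 0.5]):  # placeholder loss values
    if writer is not None:
        writer.add_scalar("train/loss", loss, step)

if writer is not None:
    writer.close()

The same gate can wrap any side effect that should happen exactly once per job, such as saving a checkpoint or writing a metrics file.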