diff --git a/src/transformers/sagemaker/trainer_sm.py b/src/transformers/sagemaker/trainer_sm.py
index 63b16ab227d..a104ee4426b 100644
--- a/src/transformers/sagemaker/trainer_sm.py
+++ b/src/transformers/sagemaker/trainer_sm.py
@@ -71,11 +71,21 @@ if is_smdistributed_available():
 
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
+        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
-        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
         if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
             raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be :obj:`True` for one process).
+        """
+        if self.is_model_parallel_enabled:
+            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
+        else:
+            return super().is_world_process_zero()
+
     def _get_train_sampler(self):
         if self.is_model_parallel_enabled:
             if self.args.group_by_length:
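
A minimal, self-contained sketch of why the hunk moves the `is_model_parallel_enabled` assignment above `super().__init__(...)`. The `Base`/`Child` classes below are hypothetical stand-ins, not the real `Trainer`/`SageMakerTrainer`, and the assumption is that the parent constructor may call an overridden method (here the new `is_world_process_zero`) that reads the flag, so the attribute must exist before the parent constructor runs:

```python
class Base:
    def __init__(self):
        # The parent constructor calls a method that the subclass overrides.
        if self.is_world_process_zero():
            print("setup performed on the main process only")

    def is_world_process_zero(self) -> bool:
        return True


class Child(Base):
    def __init__(self, model_parallel: bool = False):
        # Set the flag *before* super().__init__(); otherwise the overridden
        # is_world_process_zero() below would raise AttributeError when the
        # parent constructor invokes it.
        self.is_model_parallel_enabled = model_parallel
        super().__init__()

    def is_world_process_zero(self) -> bool:
        if self.is_model_parallel_enabled:
            return False  # placeholder for the smp rank checks in the diff
        return super().is_world_process_zero()


Child(model_parallel=True)
```

The same ordering concern explains the `super().is_world_process_zero()` fallback in the diff: when model parallelism is off, the override simply defers to the parent implementation.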