Sagemaker Model Parallel tensorboard writing fix (#10403)
* Added tb fix
* Removed local rank condition
* Updated reference to args
This commit is contained in:
parent 83d2d55c94
commit 7fc686efb1
```diff
@@ -71,11 +71,21 @@ if is_smdistributed_available():
 class SageMakerTrainer(Trainer):
     def __init__(self, args=None, **kwargs):
-        self.is_model_parallel_enabled = is_smdistributed_available() and args.mp_parameters != ""
         super().__init__(args=args, **kwargs)
+        self.is_model_parallel_enabled = is_smdistributed_available() and self.args.mp_parameters != ""
         if self.is_model_parallel_enabled and self.args.gradient_accumulation_steps != 1:
             raise ValueError("Gradient accumulation is not supported when model parallel is enabled.")
 
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be :obj:`True` for one process).
+        """
+        if self.is_model_parallel_enabled:
+            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
+        else:
+            return super().is_world_process_zero()
+
     def _get_train_sampler(self):
         if self.is_model_parallel_enabled:
             if self.args.group_by_length:
```
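The override above is what makes the tensorboard fix work: the TensorBoard callback in `Trainer` only writes event files on the process for which `is_world_process_zero()` returns `True`, and by requiring all four smdistributed ranks (`rank`, `local_rank`, `mp_rank`, `dp_rank`) to be zero, exactly one process passes that check. Below is a minimal sketch of the same gating pattern outside the library; `ToyTensorBoardLogger` and the `runs/smp` directory are illustrative names, not part of the commit.

```python
# Minimal sketch (not the library's implementation): gate event-file writing on a
# single "world process zero" check so that, under model parallelism, only one
# rank ever opens a SummaryWriter. Names here are illustrative, not from the commit.
from torch.utils.tensorboard import SummaryWriter


class ToyTensorBoardLogger:
    def __init__(self, trainer, log_dir="runs/smp"):
        # Only the global main process creates a writer; every other rank keeps
        # `self.writer` as None, so its logging calls become no-ops.
        self.writer = SummaryWriter(log_dir) if trainer.is_world_process_zero() else None

    def log_scalar(self, name, value, step):
        if self.writer is not None:
            self.writer.add_scalar(name, value, global_step=step)

    def close(self):
        if self.writer is not None:
            self.writer.close()
```

Because every rank other than the global main process keeps `writer` set to `None`, no two processes ever write to the same TensorBoard log directory.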