mirror of https://github.com/huggingface/transformers.git
synced 2025-08-02 11:11:05 +06:00

Removes SageMakerTrainer code but keeps class as wrapper (#11587)

* removed all old code
* make quality

This commit is contained in:
parent 084a187da3
commit 226e74b610
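For context, the class that remains after this commit is only a thin deprecation shim around Trainer. A minimal sketch of the resulting wrapper (the wording of the deprecation message is an assumption; the diff below only shows that it ends with "instead." and is raised as a FutureWarning):

    import warnings

    from ..trainer import Trainer


    class SageMakerTrainer(Trainer):
        def __init__(self, args=None, **kwargs):
            # Hypothetical wording; only the tail of the real message ("instead.") is visible in this diff.
            warnings.warn(
                "SageMakerTrainer is deprecated and will be removed in a future version of Transformers. "
                "You can use Trainer instead.",
                FutureWarning,
            )
            super().__init__(args=args, **kwargs)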
@@ -11,72 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler

from ..file_utils import WEIGHTS_NAME, is_torch_tpu_available
from ..modeling_utils import PreTrainedModel, unwrap_model
from ..trainer import Trainer
from ..trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    SequentialDistributedSampler,
    nested_detach,
    nested_numpify,
    reissue_pt_warnings,
)
from ..trainer_utils import PREFIX_CHECKPOINT_DIR
from ..utils import logging
from .training_args_sm import is_sagemaker_model_parallel_available


logger = logging.get_logger(__name__)

if is_sagemaker_model_parallel_available():
    import smdistributed.modelparallel.torch as smp

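    # Note: smp.step-decorated functions run the wrapped computation once per microbatch and
    # return StepOutput objects rather than plain tensors; the Trainer methods further down
    # reduce those with reduce_mean() (losses) or nested_smp_concat() (logits).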
    @smp.step()
    def forward_backward(model, inputs, gradient_accumulation_steps=1):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        loss /= gradient_accumulation_steps
        model.backward(loss)
        return loss

    @smp.step()
    def forward_only(model, inputs):
        return model(**inputs)

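    # Gathers a (possibly nested) structure of tensors from every data-parallel rank
    # (smp.CommGroup.DP_GROUP) and concatenates the results on CPU along the first dimension.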
    def smp_gather(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(smp_gather(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
            )
        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
        return torch.cat([t.cpu() for t in all_tensors], dim=0)

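    # Converts the StepOutput values returned by smp.step-decorated functions into plain CPU
    # tensors by concatenating their per-microbatch results, recursing through nested
    # lists/tuples/dicts.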
    def nested_smp_concat(tensor):
        if isinstance(tensor, (list, tuple)):
            return type(tensor)(nested_smp_concat(t) for t in tensor)
        elif isinstance(tensor, dict):
            return type(tensor)({k: nested_smp_concat(v) for k, v in tensor.items()})
        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
        # which is also the name of the decorator so Python is confused.
        return tensor.concat().detach().cpu()


class SageMakerTrainer(Trainer):
    def __init__(self, args=None, **kwargs):
        warnings.warn(
@@ -84,239 +27,4 @@ class SageMakerTrainer(Trainer):
            "instead.",
            FutureWarning,
        )
        self.is_model_parallel_enabled = is_sagemaker_model_parallel_available()
        super().__init__(args=args, **kwargs)

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on several
        machines, this is only going to be :obj:`True` for one process).
        """
        if self.is_model_parallel_enabled:
            return smp.rank() == 0 and smp.local_rank() == 0 and smp.mp_rank() == 0 and smp.dp_rank() == 0
        else:
            return super().is_world_process_zero()

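    # Training data is sharded across the data-parallel group (smp.dp_size() replicas), honoring
    # group_by_length and dataloader_drop_last just like the base Trainer.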
    def _get_train_sampler(self):
        if self.is_model_parallel_enabled:
            if self.args.group_by_length:
                return DistributedLengthGroupedSampler(
                    self.train_dataset, self.args.train_batch_size, num_replicas=smp.dp_size(), rank=smp.dp_rank()
                )
            elif not self.args.dataloader_drop_last:
                return DistributedSamplerWithLoop(
                    self.train_dataset,
                    self.args.per_device_train_batch_size,
                    num_replicas=smp.dp_size(),
                    rank=smp.dp_rank(),
                )
            else:
                return DistributedSampler(self.train_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
        else:
            return super()._get_train_sampler()

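    # Evaluation shards the dataset across data-parallel ranks with SequentialDistributedSampler.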
    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if self.is_model_parallel_enabled:
            return SequentialDistributedSampler(
                eval_dataset,
                num_replicas=smp.dp_size(),
                rank=smp.dp_rank(),
                batch_size=self.args.per_device_eval_batch_size,
            )
        else:
            return super()._get_eval_sampler(eval_dataset)

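    # The model is wrapped in smp.DistributedModel exactly once; gradient accumulation is handed
    # off via backward_passes_per_step.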
    def _wrap_model(self, model, training=True):
        if self.is_model_parallel_enabled:
            # Wrapping the base model twice in a DistributedModel will raise an error.
            if isinstance(self.model_wrapped, smp.model.DistributedModel):
                return self.model_wrapped
            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
        else:
            return super()._wrap_model(model)

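    # The optimizer created by the parent class is additionally wrapped in smp.DistributedOptimizer
    # when model parallelism is enabled.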
    def create_optimizer_and_scheduler(self, num_training_steps: int):
        super().create_optimizer_and_scheduler(num_training_steps)
        if self.is_model_parallel_enabled:
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

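    # The smp.step-decorated forward_backward helper runs forward and backward per microbatch;
    # the per-microbatch losses it returns are averaged with reduce_mean().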
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        if self.is_model_parallel_enabled:
            model.train()
            inputs = self._prepare_inputs(inputs)
            loss_mb = forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)
        else:
            return super().training_step(model, inputs)

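    # Tensors are gathered from every data-parallel rank with smp_gather before being converted
    # to numpy.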
    def _gather_and_numpify(self, tensors, name):
        if tensors is None:
            return
        if self.is_model_parallel_enabled:
            tensors = smp_gather(tensors)
            return nested_numpify(tensors)
        else:
            return super()._gather_and_numpify(tensors, name)

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
        """
        if self.is_model_parallel_enabled:
            self._save_smp(output_dir)
        elif is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

        # If on sagemaker and we are saving the main model (not a checkpoint so output_dir=None), save a copy to
        # SM_MODEL_DIR for easy deployment.
        if output_dir is None and os.getenv("SM_MODEL_DIR") is not None:
            self.save_model(output_dir=os.getenv("SM_MODEL_DIR"))

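    # Under model parallelism the state dict has to come from the wrapped smp.DistributedModel;
    # only dp_rank 0 processes reach the body below, and the actual write to disk happens on the
    # world process zero.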
    def _save_smp(self, output_dir: Optional[str] = None):
        if smp.dp_rank() != 0:
            return
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        # Calling the state_dict needs to be done on the wrapped model
        state_dict = self.model_wrapped.state_dict()

        # Rest of the save is done for the main process only
        if self.is_world_process_zero():
            model = self.model
            if not isinstance(model, PreTrainedModel):
                model = unwrap_model(model)
            if isinstance(model, PreTrainedModel):
                model.save_pretrained(output_dir, state_dict=state_dict)
            else:
                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))

            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

            # Good practice: save your training arguments together with the trained model
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

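    # Checkpointing mirrors Trainer._save_checkpoint but routes model saving through save_model
    # (and thus _save_smp) and skips every process whose dp_rank is not 0.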
    def _save_checkpoint(self, model, trial, metrics=None):
        if self.is_model_parallel_enabled:
            if smp.dp_rank() != 0:
                return

            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

            run_dir = self.args.output_dir
            self.store_flos()

            output_dir = os.path.join(run_dir, checkpoint_folder)
            self.save_model(output_dir)
            # Consolidate the state dict on all processes of dp_rank 0
            opt_state_dict = self.optimizer.state_dict()
            # Save it and the scheduler on the main process
            if self.is_world_process_zero():
                torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
                with warnings.catch_warnings(record=True) as caught_warnings:
                    torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                reissue_pt_warnings(caught_warnings)

            # Determine the new best metric / best model checkpoint
            if metrics is not None and self.args.metric_for_best_model is not None:
                metric_to_check = self.args.metric_for_best_model
                if not metric_to_check.startswith("eval_"):
                    metric_to_check = f"eval_{metric_to_check}"
                metric_value = metrics[metric_to_check]

                operator = np.greater if self.args.greater_is_better else np.less
                if (
                    self.state.best_metric is None
                    or self.state.best_model_checkpoint is None
                    or operator(metric_value, self.state.best_metric)
                ):
                    self.state.best_metric = metric_value
                    self.state.best_model_checkpoint = output_dir

            # Save the Trainer state
            if self.is_world_process_zero():
                self.state.save_to_json(os.path.join(output_dir, "trainer_state.json"))

            # Maybe delete some older checkpoints.
            if self.is_world_process_zero():
                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
        else:
            super()._save_checkpoint(model, trial, metrics=metrics)

    def _load_optimizer_and_scheduler(self, checkpoint):
        """If optimizer and scheduler states exist, load them."""
        if self.is_model_parallel_enabled:
            if checkpoint is None:
                return

            if os.path.isfile(os.path.join(checkpoint, "optimizer.pt")) and os.path.isfile(
                os.path.join(checkpoint, "scheduler.pt")
            ):
                self.optimizer.load_state_dict(
                    torch.load(os.path.join(checkpoint, "optimizer.pt"), map_location="cpu")
                )
                with warnings.catch_warnings(record=True) as caught_warnings:
                    self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, "scheduler.pt")))
                reissue_pt_warnings(caught_warnings)
        else:
            super()._load_optimizer_and_scheduler(checkpoint)

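    # Evaluation under model parallelism: run forward_only on the prepared inputs, average the
    # loss with reduce_mean(), and flatten the per-microbatch logits with nested_smp_concat
    # before handing them back to the evaluation loop.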
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.is_model_parallel_enabled:
            has_labels = all(inputs.get(k) is not None for k in self.label_names)
            inputs = self._prepare_inputs(inputs)

            if ignore_keys is None:
                if hasattr(self.model, "config"):
                    ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
                else:
                    ignore_keys = []

            with torch.no_grad():
                raw_outputs = forward_only(model, inputs)
            if has_labels:
                if isinstance(raw_outputs, dict):
                    loss_mb = raw_outputs["loss"]
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    loss_mb = raw_outputs[0]
                    logits_mb = raw_outputs[1:]

                loss = loss_mb.reduce_mean().detach().cpu()
                logits = nested_smp_concat(logits_mb)
            else:
                loss = None
                if isinstance(raw_outputs, dict):
                    logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                else:
                    logits_mb = raw_outputs
                logits = nested_smp_concat(logits_mb)

            if prediction_loss_only:
                return (loss, None, None)

            if len(logits) == 1:
                logits = logits[0]

            if has_labels:
                labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
                if len(labels) == 1:
                    labels = labels[0]
            else:
                labels = None

            return (loss, logits, labels)
        else:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)