Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Add support for past states (#5399)
* Add support for past states
* Style and forgotten self
* You mean, documenting is not enough? I have to actually add it too?
* Add memory support during evaluation
* Fix tests in eval and add TF support
* No need to change this line anymore
parent 4ade7491f4
commit 64e3d966b1
@@ -493,6 +493,10 @@ class Trainer:
             else:
                 epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
 
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if self.args.past_index >= 0:
+                self._past = None
+
             for step, inputs in enumerate(epoch_iterator):
 
                 # Skip past any already trained steps if resuming training
@@ -575,6 +579,9 @@ class Trainer:
 
         if self.tb_writer:
             self.tb_writer.close()
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
 
         logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
         return TrainOutput(self.global_step, tr_loss / self.global_step)
@@ -617,9 +624,15 @@ class Trainer:
             if isinstance(v, torch.Tensor):
                 inputs[k] = v.to(self.args.device)
 
+        if self.args.past_index >= 0 and self._past is not None:
+            inputs["mems"] = self._past
+
         outputs = model(**inputs)
         loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
 
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
         if self.args.gradient_accumulation_steps > 1:
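For orientation, the pattern this `_training_step` change implements is: feed the previous step's memory back in as ``mems``, then stash the memory returned by the current forward pass for the next step. Below is a minimal, purely illustrative sketch of that pattern outside the ``Trainer`` (the function name and ``past_index=2`` are assumptions, matching an XLNet-style model whose output tuple carries its memories at index 2):

    def run_training_steps(model, dataloader, past_index=2):
        """Illustrative only: carry the model's memory output into the next step as `mems`."""
        past = None
        for inputs in dataloader:
            if past_index >= 0 and past is not None:
                inputs["mems"] = past
            outputs = model(**inputs)  # assumed tuple layout: (loss, logits, mems, ...)
            loss = outputs[0]
            if past_index >= 0:
                # Keep the memory for the next step; in practice you may want to detach
                # it so the autograd graph does not keep growing across steps.
                past = outputs[past_index]
            loss.backward()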
@@ -802,12 +815,17 @@ class Trainer:
         if is_torch_tpu_available():
             dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
 
+        if self.args.past_index >= 0:
+            past = None
+
         for inputs in tqdm(dataloader, desc=description):
             has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
 
             for k, v in inputs.items():
                 if isinstance(v, torch.Tensor):
                     inputs[k] = v.to(self.args.device)
+            if self.args.past_index >= 0:
+                inputs["mems"] = past
 
             with torch.no_grad():
                 outputs = model(**inputs)
@@ -816,6 +834,8 @@ class Trainer:
                     eval_losses += [step_eval_loss.mean().item()]
                 else:
                     logits = outputs[0]
+                if self.args.past_index >= 0:
+                    past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
 
             if not prediction_loss_only:
                 if preds is None:
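The ``past_index if has_labels else past_index - 1`` expression accounts for the fact that, when no labels are passed, the model returns no loss and every entry of the output tuple shifts down by one. A small illustrative sketch of the two layouts (not taken from the repository):

    # With labels:    outputs = (loss, logits, mems, ...)  -> mems sits at past_index (e.g. 2)
    # Without labels: outputs = (logits, mems, ...)        -> mems sits at past_index - 1 (e.g. 1)
    def mems_position(past_index: int, has_labels: bool) -> int:
        return past_index if has_labels else past_index - 1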
@@ -240,6 +240,10 @@ class TFTrainer:
 
         step: int = 1
 
+        # Reset the past mems state at the beginning of the evaluation if necessary.
+        if self.args.past_index >= 0:
+            self._past = None
+
         for features, labels in dataset:
             step = tf.convert_to_tensor(step, dtype=tf.int64)
             loss, logits = self._evaluate_steps(features, labels)
@@ -288,6 +292,10 @@ class TFTrainer:
             if not key.startswith("eval_"):
                 metrics[f"eval_{key}"] = metrics.pop(key)
 
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
 
     def _log(self, logs: Dict[str, float]) -> None:
@@ -405,6 +413,9 @@ class TFTrainer:
         logger.info("  Total optimization steps = %d", t_total)
 
         for epoch_iter in range(epochs_trained, int(epochs + 1)):
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if self.args.past_index >= 0:
+                self._past = None
             for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
                 self.global_step = iterations.numpy()
                 self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
@@ -444,6 +455,10 @@ class TFTrainer:
                 if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
                     break
 
+        if self.args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
     def _training_steps(self, ds, optimizer):
         """
         Returns a generator over training steps (i.e. parameters update).
@@ -518,10 +533,15 @@ class TFTrainer:
             labels: the batched labels.
             training: run the model in training mode or not
         """
+        if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
+            features["mems"] = self._past
         if isinstance(labels, (dict)):
-            loss, logits = self.model(features, training=training, **labels)[:2]
+            outputs = self.model(features, training=training, **labels)[:2]
         else:
-            loss, logits = self.model(features, labels=labels, training=training)[:2]
+            outputs = self.model(features, labels=labels, training=training)[:2]
+        loss, logits = outputs[:2]
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
         loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
 
         return loss, logits
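In the TF ``_run_model`` change, ``getattr(self, "_past", None)`` guards the very first call, where ``_past`` may not have been set yet. A hedged, self-contained sketch of the same guard-and-store pattern in isolation (the class and its call signature are illustrative, not the TF trainer's API):

    class MemoryCarryingRunner:
        """Illustrative only: guard the first call, then store the returned memories."""

        def __init__(self, model, past_index=2):
            self.model = model
            self.past_index = past_index

        def run(self, features, labels, training=False):
            # On the very first call `_past` does not exist yet, hence the getattr guard.
            if self.past_index >= 0 and getattr(self, "_past", None) is not None:
                features = dict(features, mems=self._past)
            outputs = self.model(features, labels=labels, training=training)
            loss, logits = outputs[:2]
            if self.past_index >= 0:
                self._past = outputs[self.past_index]
            return loss, logits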
@@ -102,6 +102,11 @@ class TrainingArguments:
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
+        past_index (:obj:`int`, `optional`, defaults to -1):
+            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
+            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
+            ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
+            at the next training step under the keyword argument ``mems``.
     """
 
     output_dir: str = field(
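As a usage illustration of the documented behaviour (not part of the diff): a minimal sketch assuming an XLNet checkpoint, whose language-modeling head returns ``(loss, logits, mems)`` so the memories sit at index 2, and a ``train_dataset`` that has already been built:

    from transformers import Trainer, TrainingArguments, XLNetLMHeadModel

    model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
    args = TrainingArguments(
        output_dir="./xlnet-output",
        past_index=2,  # feed outputs[2] (the mems) back to the model as `mems` each step
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
    trainer.train()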
@@ -203,6 +208,11 @@ class TrainingArguments:
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )
 
+    past_index: int = field(
+        default=-1,
+        metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
+    )
+
     @property
     def train_batch_size(self) -> int:
         """
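Because ``past_index`` is an ordinary dataclass field, ``HfArgumentParser`` exposes it on the command line like the other training arguments; a small sketch (the script invocation in the comment is illustrative):

    from transformers import HfArgumentParser, TrainingArguments

    parser = HfArgumentParser(TrainingArguments)
    # e.g. invoked as: python run_training.py --output_dir ./out --past_index 2
    (training_args,) = parser.parse_args_into_dataclasses()
    print(training_args.past_index)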
@@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments):
         dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
             or not.
+        past_index (:obj:`int`, `optional`, defaults to -1):
+            Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc:`XLNet <../model_doc/xlnet>` can
+            make use of the past hidden states for their predictions. If this argument is set to a positive int, the
+            ``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
+            at the next training step under the keyword argument ``mems``.
         tpu_name (:obj:`str`, `optional`):
             The name of the TPU the process is running on.
         eval_steps (:obj:`int`, `optional`, defaults to 1000):