Add support for past states (#5399)

* Add support for past states

* Style and forgotten self

* You mean, documenting is not enough? I have to actually add it too?

* Add memory support during evaluation

* Fix tests in eval and add TF support

* No need to change this line anymore
This commit is contained in:
Sylvain Gugger 2020-07-01 08:11:55 -04:00 committed by GitHub
parent 4ade7491f4
commit 64e3d966b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 2 deletions

View File

@ -493,6 +493,10 @@ class Trainer:
else:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master())
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
@ -575,6 +579,9 @@ class Trainer:
if self.tb_writer:
self.tb_writer.close()
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
return TrainOutput(self.global_step, tr_loss / self.global_step)
@ -617,9 +624,15 @@ class Trainer:
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.args.gradient_accumulation_steps > 1:
@ -802,12 +815,17 @@ class Trainer:
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
if self.args.past_index >= 0:
past = None
for inputs in tqdm(dataloader, desc=description):
has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])
for k, v in inputs.items():
if isinstance(v, torch.Tensor):
inputs[k] = v.to(self.args.device)
if self.args.past_index >= 0:
inputs["mems"] = past
with torch.no_grad():
outputs = model(**inputs)
@ -816,6 +834,8 @@ class Trainer:
eval_losses += [step_eval_loss.mean().item()]
else:
logits = outputs[0]
if self.args.past_index >= 0:
past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
if not prediction_loss_only:
if preds is None:

View File

@ -240,6 +240,10 @@ class TFTrainer:
step: int = 1
# Reset the past mems state at the beginning of the evaluation if necessary.
if self.args.past_index >= 0:
self._past = None
for features, labels in dataset:
step = tf.convert_to_tensor(step, dtype=tf.int64)
loss, logits = self._evaluate_steps(features, labels)
@ -288,6 +292,10 @@ class TFTrainer:
if not key.startswith("eval_"):
metrics[f"eval_{key}"] = metrics.pop(key)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _log(self, logs: Dict[str, float]) -> None:
@ -405,6 +413,9 @@ class TFTrainer:
logger.info(" Total optimization steps = %d", t_total)
for epoch_iter in range(epochs_trained, int(epochs + 1)):
# Reset the past mems state at the beginning of each epoch if necessary.
if self.args.past_index >= 0:
self._past = None
for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)):
self.global_step = iterations.numpy()
self.epoch_logging = epoch_iter - 1 + (step + 1) / steps_per_epoch
@ -444,6 +455,10 @@ class TFTrainer:
if self.args.max_steps > 0 and self.global_step % self.args.max_steps == 0:
break
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
def _training_steps(self, ds, optimizer):
"""
Returns a generator over training steps (i.e. parameters update).
@ -518,10 +533,15 @@ class TFTrainer:
labels: the batched labels.
training: run the model in training mode or not
"""
if self.args.past_index >= 0 and getattr(self, "_past", None) is not None:
features["mems"] = self._past
if isinstance(labels, (dict)):
loss, logits = self.model(features, training=training, **labels)[:2]
outputs = self.model(features, training=training, **labels)[:2]
else:
loss, logits = self.model(features, labels=labels, training=training)[:2]
outputs = self.model(features, labels=labels, training=training)[:2]
loss, logits = outputs[:2]
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
loss += sum(self.model.losses) * (1.0 / self.args.n_gpu)
return loss, logits

View File

@ -102,6 +102,11 @@ class TrainingArguments:
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``.
"""
output_dir: str = field(
@ -203,6 +208,11 @@ class TrainingArguments:
default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
)
past_index: int = field(
default=-1,
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
)
@property
def train_batch_size(self) -> int:
"""

View File

@ -85,6 +85,11 @@ class TFTrainingArguments(TrainingArguments):
dataloader_drop_last (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
or not.
past_index (:obj:`int`, `optional`, defaults to -1):
Some models like :doc:`TransformerXL <../model_doc/transformerxl>` or :doc`XLNet <../model_doc/xlnet>` can
make use of the past hidden states for their predictions. If this argument is set to a positive int, the
``Trainer`` will use the corresponding output (usually index 2) as the past state and feed it to the model
at the next training step under the keyword argument ``mems``.
tpu_name (:obj:`str`, `optional`):
The name of the TPU the process is running on.
eval_steps (:obj:`int`, `optional`, defaults to 1000):