[s2s] save first batch to json for debugging purposes (#6810)

This commit is contained in:
Sam Shleifer 2020-10-06 16:11:56 -04:00 committed by GitHub
parent 2b574e7c60
commit 500be01c5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 0 deletions

View File

@ -33,6 +33,7 @@ from utils import (
lmap,
pickle_save,
save_git_info,
save_json,
use_task_specific_params,
)
@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer):
self.dataset_class = (
Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
)
self.already_saved_batch = False
self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
if self.hparams.eval_max_gen_length is not None:
self.eval_max_length = self.hparams.eval_max_gen_length
@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer):
self.eval_max_length = self.model.config.max_length
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
"""A debugging utility"""
readable_batch = {
k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
}
save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
self.already_saved_batch = True
return readable_batch
def forward(self, input_ids, **kwargs):
return self.model(input_ids, **kwargs)
@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer):
decoder_input_ids = self.model._shift_right(tgt_ids)
else:
decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
if not self.already_saved_batch: # This would be slightly better if it only happened on rank zero
batch["decoder_input_ids"] = decoder_input_ids
self.save_readable_batch(batch)
outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
lm_logits = outputs[0]

View File

@ -422,6 +422,10 @@ def test_finetune(model):
assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
assert bart.decoder.embed_tokens == bart.shared
example_batch = load_json(module.output_dir / "text_batch.json")
assert isinstance(example_batch, dict)
assert len(example_batch) >= 4
def test_finetune_extra_model_args():
args_d: dict = CHEAP_ARGS.copy()