Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[s2s] save first batch to json for debugging purposes (#6810)
This commit is contained in:
parent 2b574e7c60
commit 500be01c5d
@@ -33,6 +33,7 @@ from utils import (
     lmap,
     pickle_save,
     save_git_info,
+    save_json,
     use_task_specific_params,
 )
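For context, the `save_json` helper imported here lives in this era's `examples/seq2seq/utils.py` and is a thin wrapper around `json.dump`; roughly (a sketch, not the verbatim source):

```python
import json


def save_json(content, path, indent=4, **json_dump_kwargs):
    # Serialize `content` to `path` as pretty-printed JSON.
    with open(path, "w") as f:
        json.dump(content, f, indent=indent, **json_dump_kwargs)
```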
@@ -105,6 +106,7 @@ class SummarizationModule(BaseTransformer):
         self.dataset_class = (
             Seq2SeqDataset if hasattr(self.tokenizer, "prepare_seq2seq_batch") else LegacySeq2SeqDataset
         )
+        self.already_saved_batch = False
         self.eval_beams = self.model.config.num_beams if self.hparams.eval_beams is None else self.hparams.eval_beams
         if self.hparams.eval_max_gen_length is not None:
             self.eval_max_length = self.hparams.eval_max_gen_length
@@ -112,6 +114,17 @@ class SummarizationModule(BaseTransformer):
             self.eval_max_length = self.model.config.max_length
         self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
 
+    def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
+        """A debugging utility"""
+        readable_batch = {
+            k: self.tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape for k, v in batch.items()
+        }
+        save_json(readable_batch, Path(self.output_dir) / "text_batch.json")
+        save_json({k: v.tolist() for k, v in batch.items()}, Path(self.output_dir) / "tok_batch.json")
+
+        self.already_saved_batch = True
+        return readable_batch
+
     def forward(self, input_ids, **kwargs):
         return self.model(input_ids, **kwargs)
 
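The utility is easy to reproduce outside the trainer. A minimal standalone sketch, assuming a BART tokenizer and a toy batch — the two file names match the diff, while `debug_batches` and the input text are illustrative, and `save_json` is approximated by `json.dump`:

```python
import json
from pathlib import Path

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
output_dir = Path("debug_batches")  # hypothetical output directory
output_dir.mkdir(exist_ok=True)

enc = tokenizer(["One document to summarize."], return_tensors="pt", padding=True)
batch = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}

# Decode id tensors back to strings; for masks, record only the shape.
# torch.Size is a tuple subclass, so json serializes it as a list.
readable_batch = {
    k: tokenizer.batch_decode(v.tolist()) if "mask" not in k else v.shape
    for k, v in batch.items()
}

with open(output_dir / "text_batch.json", "w") as f:
    json.dump(readable_batch, f, indent=4)
with open(output_dir / "tok_batch.json", "w") as f:
    json.dump({k: v.tolist() for k, v in batch.items()}, f, indent=4)
```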
@@ -129,6 +142,9 @@ class SummarizationModule(BaseTransformer):
             decoder_input_ids = self.model._shift_right(tgt_ids)
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
+        if not self.already_saved_batch:  # This would be slightly better if it only happened on rank zero
+            batch["decoder_input_ids"] = decoder_input_ids
+            self.save_readable_batch(batch)
 
         outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
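For readers unfamiliar with the `shift_tokens_right` used in the non-T5 branch above: in this era's `examples/seq2seq/utils.py` it builds BART-style decoder inputs by shifting the target ids one position right and wrapping the last non-pad token (usually `</s>`) to the front. A hand-written sketch of that behavior, not the verbatim helper:

```python
import torch


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    prev_output_tokens = input_ids.clone()
    # Index of the last non-pad token in each row (usually the EOS token).
    index_of_eos = (input_ids.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
    # Wrap that token to position 0, then shift everything else right by one.
    prev_output_tokens[:, 0] = input_ids.gather(1, index_of_eos).squeeze()
    prev_output_tokens[:, 1:] = input_ids[:, :-1]
    return prev_output_tokens


tgt_ids = torch.tensor([[0, 8, 9, 2, 1, 1]])  # <s> ... </s> <pad> <pad>
print(shift_tokens_right(tgt_ids, pad_token_id=1))  # tensor([[2, 0, 8, 9, 2, 1]])
```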
@@ -422,6 +422,10 @@ def test_finetune(model):
     assert bart.decoder.embed_tokens == bart.encoder.embed_tokens
     assert bart.decoder.embed_tokens == bart.shared
 
+    example_batch = load_json(module.output_dir / "text_batch.json")
+    assert isinstance(example_batch, dict)
+    assert len(example_batch) >= 4
+
 
 def test_finetune_extra_model_args():
     args_d: dict = CHEAP_ARGS.copy()
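The `len(example_batch) >= 4` bound lines up with the keys a saved batch plausibly carries once `_step` has injected the decoder inputs: `input_ids`, `attention_mask`, `labels`, and `decoder_input_ids`. An illustrative `text_batch.json` (invented values, not captured from a real run), shown as a Python literal:

```python
example_batch = {
    "input_ids": ["<s> One document to summarize.</s>"],
    "attention_mask": [1, 9],  # masks are saved as their shape, not decoded
    "labels": ["<s> A summary.</s>"],
    "decoder_input_ids": ["</s><s> A summary."],
}
assert isinstance(example_batch, dict) and len(example_batch) >= 4
```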