From aceb3fbaf412e7ee5ed64b27a89cd15486c76955 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 5 Mar 2020 16:34:47 +0100 Subject: [PATCH] only do output_past=True for language generation in bart --- tests/test_modeling_bart.py | 2 -- tests/test_modeling_common.py | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index c248c1e73d1..2333e5a8c7f 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -64,7 +64,6 @@ class ModelTester: self.eos_token_id = 2 self.pad_token_id = 1 self.bos_token_id = 0 - self.output_past = True torch.manual_seed(0) def prepare_config_and_inputs_for_common(self): @@ -86,7 +85,6 @@ class ModelTester: eos_token_ids=self.eos_token_id, bos_token_id=self.bos_token_id, pad_token_id=self.pad_token_id, - output_past=self.output_past, ) inputs_dict = prepare_bart_inputs_dict(config, input_ids) return config, inputs_dict diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a52d7469476..bc7bc967e28 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -628,6 +628,9 @@ class ModelTesterMixin: "input_ids", None ) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed. + if self.is_encoder_decoder: + config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models + for model_class in self.all_generative_model_classes: model = model_class(config) model.to(torch_device)