mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 18:22:34 +06:00
only do output_past=True for language generation in bart
This commit is contained in:
parent
7cba11fb9b
commit
aceb3fbaf4
@ -64,7 +64,6 @@ class ModelTester:
|
|||||||
self.eos_token_id = 2
|
self.eos_token_id = 2
|
||||||
self.pad_token_id = 1
|
self.pad_token_id = 1
|
||||||
self.bos_token_id = 0
|
self.bos_token_id = 0
|
||||||
self.output_past = True
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
@ -86,7 +85,6 @@ class ModelTester:
|
|||||||
eos_token_ids=self.eos_token_id,
|
eos_token_ids=self.eos_token_id,
|
||||||
bos_token_id=self.bos_token_id,
|
bos_token_id=self.bos_token_id,
|
||||||
pad_token_id=self.pad_token_id,
|
pad_token_id=self.pad_token_id,
|
||||||
output_past=self.output_past,
|
|
||||||
)
|
)
|
||||||
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
|
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
@ -628,6 +628,9 @@ class ModelTesterMixin:
|
|||||||
"input_ids", None
|
"input_ids", None
|
||||||
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
) # TODO (PVP): ugly workaround to make code work for t5 for the moment - has to changed when t5 is fixed.
|
||||||
|
|
||||||
|
if self.is_encoder_decoder:
|
||||||
|
config.output_past = True # needed for Bart TODO: might have to update for other encoder-decoder models
|
||||||
|
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
Loading…
Reference in New Issue
Block a user