mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-23 22:38:58 +06:00
fixed typo
This commit is contained in:
parent
a2c8e516c2
commit
374deef48d
@ -855,7 +855,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
|
||||
cur_len = 1
|
||||
|
||||
# put model in generation mode if it has one
|
||||
if hasattr(self.model, "generation_mode"):
|
||||
if hasattr(self.model, "decoder") and hasattr(self.model.decoder, "generation_mode"):
|
||||
self.model.decoder.generation_mode = True
|
||||
else:
|
||||
encoder_inputs = None
|
||||
|
@ -287,7 +287,7 @@ class BartHeadTests(unittest.TestCase):
|
||||
new_input_ids = lm_model.generate(
|
||||
input_ids.clone(), num_return_sequences=1, num_beams=2, no_repeat_ngram_size=3, max_length=max_length
|
||||
)
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length - 1))
|
||||
self.assertEqual(new_input_ids.shape, (input_ids.shape[0], max_length))
|
||||
# TODO(SS): uneven length batches, empty inputs
|
||||
|
||||
def test_shift_tokens_right(self):
|
||||
|
Loading…
Reference in New Issue
Block a user