mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Remove unhelpful bart warning (#7391)
This commit is contained in:
parent
5ff0d6d7d0
commit
38a1b03f4d
@ -550,13 +550,8 @@ class BartDecoder(nn.Module):
|
||||
positions = self.embed_positions(input_ids, use_cache=use_cache)
|
||||
|
||||
if use_cache:
|
||||
if input_ids.shape[1] != 1 or past_key_values is None:
|
||||
# if you make this an AssertionError, test_benchmark breaks.
|
||||
warnings.warn("pass decoder_past_key_value_states to use_cache")
|
||||
|
||||
input_ids = input_ids[:, -1:]
|
||||
positions = positions[:, -1:] # happens after we embed them
|
||||
# assert input_ids.ne(self.padding_idx).any()
|
||||
positions = positions[:, -1:]
|
||||
|
||||
x = self.embed_tokens(input_ids) * self.embed_scale
|
||||
x += positions
|
||||
|
Loading…
Reference in New Issue
Block a user