Remove unhelpful bart warning (#7391)

This commit is contained in:
Sam Shleifer 2020-09-25 11:01:07 -04:00 committed by GitHub
parent 5ff0d6d7d0
commit 38a1b03f4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -550,13 +550,8 @@ class BartDecoder(nn.Module):
positions = self.embed_positions(input_ids, use_cache=use_cache)
if use_cache:
if input_ids.shape[1] != 1 or past_key_values is None:
# if you make this an AssertionError, test_benchmark breaks.
warnings.warn("pass decoder_past_key_value_states to use_cache")
input_ids = input_ids[:, -1:]
positions = positions[:, -1:] # happens after we embed them
# assert input_ids.ne(self.padding_idx).any()
positions = positions[:, -1:]
x = self.embed_tokens(input_ids) * self.embed_scale
x += positions