Replace assertions with ValueError exceptions (#14061)

* Replace assertions with ValueError exceptions

* Format error messages as suggested
David del Río Medina 2021-10-21 13:32:27 +02:00 committed by GitHub
parent 9e4ea25175
commit 3187228206
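The change applies one pattern throughout `GenerationMixin`: bare `assert` statements become explicit `if`/`raise ValueError` checks. Two things motivate this: assertions are stripped entirely when Python runs with optimizations enabled (`python -O`), and they raise `AssertionError`, which callers rarely catch for input validation. A minimal standalone sketch of the before/after (illustrative, not taken from the diff below):

def validate_before(eos_token_id, pad_token_id):
    # Old style: silently skipped under `python -O`, raises AssertionError otherwise.
    if eos_token_id is not None:
        assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."

def validate_after(eos_token_id, pad_token_id):
    # New style: always enforced, raises a ValueError that callers can catch specifically.
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")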

src/transformers/generation_utils.py

@@ -479,7 +479,8 @@ class GenerationMixin:
             model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
 
         if is_encoder_decoder:
-            assert encoder_outputs is not None
+            if encoder_outputs is None:
+                raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
             encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
                 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
             )
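How this first guard surfaces to a caller: a sketch, assuming the `transformers.generation_utils` module path of this era and the staticmethod signature implied by the hunk above; previously the same misuse raised a bare `AssertionError` with no message.

import torch
from transformers.generation_utils import GenerationMixin

input_ids = torch.tensor([[101, 2023, 102]])
try:
    # Asking for encoder-decoder expansion without passing `encoder_outputs`
    GenerationMixin._expand_inputs_for_generation(input_ids, expand_size=2, is_encoder_decoder=True)
except ValueError as err:
    print(err)  # If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.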
@@ -1327,7 +1328,8 @@ class GenerationMixin:
 
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
-                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
             # update generated ids, model inputs, and length for next step
@@ -1567,7 +1569,8 @@ class GenerationMixin:
 
             # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
-                assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
+                if pad_token_id is None:
+                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
             # update generated ids, model inputs, and length for next step
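This guard appears twice because greedy search and multinomial sampling share the same finalization step. A hypothetical standalone helper (`finalize_next_tokens` is not a library function) showing the masking arithmetic the check protects:

import torch

def finalize_next_tokens(next_tokens, unfinished_sequences, eos_token_id=None, pad_token_id=None):
    if eos_token_id is not None:
        if pad_token_id is None:
            raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
        # Finished sequences (unfinished == 0) receive pad_token_id; active ones keep their token.
        next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    return next_tokens

next_tokens = torch.tensor([42, 7])
unfinished = torch.tensor([1, 0])  # the second sequence already produced EOS
print(finalize_next_tokens(next_tokens, unfinished, eos_token_id=2, pad_token_id=0))  # tensor([42, 0])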
@@ -1761,9 +1764,10 @@ class GenerationMixin:
 
         batch_beam_size, cur_len = input_ids.shape
 
-        assert (
-            num_beams * batch_size == batch_beam_size
-        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+        if num_beams * batch_size != batch_beam_size:
+            raise ValueError(
+                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+            )
 
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
@@ -2361,9 +2365,10 @@ class GenerationMixin:
 
         batch_beam_size, cur_len = input_ids.shape
 
-        assert (
-            num_beams * batch_size == batch_beam_size
-        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+        if num_beams * batch_size != batch_beam_size:
+            raise ValueError(
+                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+            )
 
         beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
         # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
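The beam-search check enforces the invariant that `input_ids` has already been expanded to one row per beam, i.e. `batch_beam_size == batch_size * num_beams`. A small self-contained sketch with illustrative numbers:

import torch

batch_size, num_beams, cur_len = 2, 3, 4
# Correctly expanded input: one row per (batch item, beam) pair.
input_ids = torch.zeros((batch_size * num_beams, cur_len), dtype=torch.long)

batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
    raise ValueError(
        f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
    )
# batch_beam_size is 6 == 2 * 3, so the check passes; a mis-expanded input now
# fails with a descriptive ValueError instead of a bare AssertionError.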