From 3187228206cce052c5df0a8643fe85d2fd50e6a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20del=20R=C3=ADo=20Medina?= Date: Thu, 21 Oct 2021 13:32:27 +0200 Subject: [PATCH] Replace assertions with ValueError exceptions (#14061) * Replace assertions with ValueError exceptions * Format error messages as suggested --- src/transformers/generation_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index a86a3d94be3..e6cb59ea74c 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -479,7 +479,8 @@ class GenerationMixin: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) if is_encoder_decoder: - assert encoder_outputs is not None + if encoder_outputs is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) ) @@ -1327,7 +1328,8 @@ class GenerationMixin: # finished sentences should have their next token be a padding token if eos_token_id is not None: - assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step @@ -1567,7 +1569,8 @@ class GenerationMixin: # finished sentences should have their next token be a padding token if eos_token_id is not None: - assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined." + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) # update generated ids, model inputs, and length for next step @@ -1761,9 +1764,10 @@ class GenerationMixin: batch_beam_size, cur_len = input_ids.shape - assert ( - num_beams * batch_size == batch_beam_size - ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores[:, 1:] = -1e9 @@ -2361,9 +2365,10 @@ class GenerationMixin: batch_beam_size, cur_len = input_ids.shape - assert ( - num_beams * batch_size == batch_beam_size - ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in