mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
Replace assertions with ValueError exceptions (#14061)
* Replace assertions with ValueError exceptions * Format error messages as suggested
This commit is contained in:
parent
9e4ea25175
commit
3187228206
@ -479,7 +479,8 @@ class GenerationMixin:
|
||||
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
|
||||
|
||||
if is_encoder_decoder:
|
||||
assert encoder_outputs is not None
|
||||
if encoder_outputs is None:
|
||||
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
|
||||
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
|
||||
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
|
||||
)
|
||||
@ -1327,7 +1328,8 @@ class GenerationMixin:
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
@ -1567,7 +1569,8 @@ class GenerationMixin:
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
# update generated ids, model inputs, and length for next step
|
||||
@ -1761,9 +1764,10 @@ class GenerationMixin:
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
|
||||
beam_scores[:, 1:] = -1e9
|
||||
@ -2361,9 +2365,10 @@ class GenerationMixin:
|
||||
|
||||
batch_beam_size, cur_len = input_ids.shape
|
||||
|
||||
assert (
|
||||
num_beams * batch_size == batch_beam_size
|
||||
), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
if num_beams * batch_size != batch_beam_size:
|
||||
raise ValueError(
|
||||
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
||||
)
|
||||
|
||||
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
|
||||
# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
|
||||
|
Loading…
Reference in New Issue
Block a user