Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-03 12:50:06 +06:00)
Merge branch 'main' into random-serve-fixes
Commit de606cddde
@@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand):
             )
+            args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
+            if args.model_name_or_path is None:
+                raise ValueError(
+                    "When connecting to a server, please specify a model name with the --model_name_or_path flag."
+                )
         else:
             self.spawn_backend = True
             args.model_name_or_path = args.model_name_or_path_or_address

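For reference, `rsplit(":", 1)` splits on the last colon only, which is what lets an address like `localhost:8000` be separated into host and port. A minimal standalone sketch (the address value is illustrative):

    # Minimal sketch of the host/port parsing above; the address is an example value.
    address = "localhost:8000"
    host, port = address.rsplit(":", 1)  # split on the LAST colon only
    print(host, port)  # localhost 8000
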
@@ -628,7 +628,7 @@ class ServeCommand(BaseTransformersCLICommand):

         self.loaded_model = f"{model_id}@{revision}"

-        print("Loaded model", self.loaded_model)
+        logger.warning(f"Loaded model {self.loaded_model}")
         return model, tokenizer

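The `print` call is replaced with the module logger so the message goes through Python's `logging` machinery (handlers, levels, formatting) instead of straight to stdout. A standalone sketch, with a locally created logger rather than transformers' own logging utilities:

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)

    loaded_model = "gpt2@main"  # placeholder value for illustration
    logger.warning(f"Loaded model {loaded_model}")  # honors configured handlers/levels, unlike print
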
@@ -3773,16 +3773,28 @@ class GenerationMixin(ContinuousMixin):
             Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
             """
             # a. Can the open beams improve the top completed scores?
-            # early_stopping == False -> apply heuristic = always get the best score from
-            # `cur_len - decoder_prompt_len`. See the discussion below for more details.
-            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
+            # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
+            # sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
+            # `max_length` there.
+            # !!
+            # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
+            # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
+            # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cuts
+            # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
+            # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
+            # generations, and changing them is BC breaking.
+            # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
+            # See the discussion below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # !!
             if early_stopping == "never" and length_penalty > 0.0:
                 best_hypothetical_length = max_length - decoder_prompt_len
             else:
                 best_hypothetical_length = cur_len - decoder_prompt_len

             # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
             # applying length penalty
             best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
             worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
             improvement_possible = torch.any(best_possible_running_score > worst_finished_score)

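To make the heuristic concrete, here is a self-contained numeric sketch of the comparison above; all tensor values, lengths, and penalties are made up for illustration, and finished-beam scores are assumed to be already length-penalized as in the real loop:

    import torch

    # Hypothetical state: batch of 1, three open beams and three finished hypotheses.
    running_beam_scores = torch.tensor([[-2.0, -3.5, -4.0]])  # summed logprobs, best beam first
    beam_scores = torch.tensor([[-2.5, -2.8, -3.0]])          # finished scores (length-penalized)
    is_sent_finished = torch.tensor([[True, True, True]])

    cur_len, decoder_prompt_len, max_length = 10, 2, 20
    early_stopping, length_penalty = False, 1.0

    # Default heuristic: pretend the best open beam stops right now.
    if early_stopping == "never" and length_penalty > 0.0:
        best_hypothetical_length = max_length - decoder_prompt_len
    else:
        best_hypothetical_length = cur_len - decoder_prompt_len

    # Best case: future tokens have logprob 0, so the raw sum is unchanged before the penalty.
    best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
    worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
    print(torch.any(best_possible_running_score > worst_finished_score))  # tensor(True): -2.0/8 > -3.0
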
@@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase):
             self.assertIn("chat interface", cs.out.lower())

     @patch.object(ChatCommand, "run")
-    def test_cli_dispatch(self, run_mock):
+    def test_cli_dispatch_model(self, run_mock):
+        """
+        Running transformers chat with just a model should work & spawn a serve underneath
+        """
         args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
         with patch("sys.argv", args):
             cli.main()
         run_mock.assert_called_once()

+    def test_cli_dispatch_url(self):
+        """
+        Running transformers chat with just a URL should not work, as a model should additionally be specified
+        """
+        args = ["transformers", "chat", "localhost:8000"]
+        with self.assertRaises(ValueError):
+            with patch("sys.argv", args):
+                cli.main()
+
+    @patch.object(ChatCommand, "run")
+    def test_cli_dispatch_url_and_model(self, run_mock):
+        """
+        Running transformers chat with a URL and a model should work
+        """
+        args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"]
+        with patch("sys.argv", args):
+            cli.main()
+        run_mock.assert_called_once()
+
     def test_parsed_args(self):
         with (
             patch.object(ChatCommand, "__init__", return_value=None) as init_mock,

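The dispatch behavior these tests pin down can also be exercised directly; a sketch assuming the test module's `cli` is `transformers.commands.transformers_cli` (the entry point behind the `transformers` command):

    from unittest.mock import patch

    from transformers.commands import transformers_cli as cli

    # Equivalent of test_cli_dispatch_url: a bare address must raise, since no model was named.
    with patch("sys.argv", ["transformers", "chat", "localhost:8000"]):
        try:
            cli.main()
        except ValueError as err:
            print(err)  # "When connecting to a server, please specify a model name ..."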