diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py
index 91979590046..8f6f49f26bc 100644
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand):
             )
             args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
+
+            if args.model_name_or_path is None:
+                raise ValueError(
+                    "When connecting to a server, please specify a model name with the --model_name_or_path flag."
+                )
         else:
             self.spawn_backend = True
             args.model_name_or_path = args.model_name_or_path_or_address
 
diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 0eab86f6123..d8f61603692 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -628,7 +628,7 @@ class ServeCommand(BaseTransformersCLICommand):
 
         self.loaded_model = f"{model_id}@{revision}"
 
-        print("Loaded model", self.loaded_model)
+        logger.warning(f"Loaded model {self.loaded_model}")
 
         return model, tokenizer
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index de6ae44bb5a..13a1c83a719 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -3773,16 +3773,28 @@ class GenerationMixin(ContinuousMixin):
             Beam Search stopping condition -- halts the generation loop if any of these conditions becomes False
             """
             # a. Can the open beams improve the top completed scores?
-            # early_stopping == False -> apply heuristic = always get the best score from
-            # `cur_len - decoder_prompt_len`. See the discussion below for more details.
-            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # early_stopping == False -> apply heuristic = always get the best score from `cur_len - decoder_prompt_len`.
             # early_stopping == "never" -> compute the best score from `max_length` or `cur_len`, depending on the
             #     sign of `length_penalty`. Positive `length_penalty` favors longer sequences, thus we use
             #     `max_length` there.
+            # !!
+            # Be sure to check the docstring for `early_stopping` and `length_penalty`. The default parameterization
+            # does NOT correspond to a canonical beam search implementation, and tends to favor shorter output sequences
+            # compared to it (the heuristic active by default underestimates the maximum achievable score, and thus cuts
+            # generation short). Also, be mindful that length penalty > 0.0 actually favors longer sequences, despite
+            # its name. These modifications were empirically found in the past (prior to 2022) to produce better quality
+            # generations, and changing them is BC breaking.
+            # For a canonical beam search implementation, set `early_stopping="never"` and `length_penalty=0.0`.
+            # See the discussion below for more details.
+            # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
+            # !!
             if early_stopping == "never" and length_penalty > 0.0:
                 best_hypothetical_length = max_length - decoder_prompt_len
             else:
                 best_hypothetical_length = cur_len - decoder_prompt_len
+
+            # best-case scenario: the next tokens have logprobs=0 (probability=1), and the score stays the same before
+            # applying length penalty
             best_possible_running_score = running_beam_scores[:, :1] / (best_hypothetical_length**length_penalty)
             worst_finished_score = torch.where(is_sent_finished, torch.min(beam_scores, dim=1, keepdim=True)[0], -1.0e9)
             improvement_possible = torch.any(best_possible_running_score > worst_finished_score)
 
diff --git a/tests/commands/test_chat.py b/tests/commands/test_chat.py
index 6ba3413fafa..e07df4a3938 100644
--- a/tests/commands/test_chat.py
+++ b/tests/commands/test_chat.py
@@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase):
         self.assertIn("chat interface", cs.out.lower())
 
     @patch.object(ChatCommand, "run")
-    def test_cli_dispatch(self, run_mock):
+    def test_cli_dispatch_model(self, run_mock):
+        """
+        Running `transformers chat` with just a model should work and spawn a `serve` backend underneath
+        """
         args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
         with patch("sys.argv", args):
            cli.main()
         run_mock.assert_called_once()
 
+    def test_cli_dispatch_url(self):
+        """
+        Running `transformers chat` with just a URL should fail, as a model must additionally be specified
+        """
+        args = ["transformers", "chat", "localhost:8000"]
+        with self.assertRaises(ValueError):
+            with patch("sys.argv", args):
+                cli.main()
+
+    @patch.object(ChatCommand, "run")
+    def test_cli_dispatch_url_and_model(self, run_mock):
+        """
+        Running `transformers chat` with a URL and a model should work
+        """
+        args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"]
+        with patch("sys.argv", args):
+            cli.main()
+        run_mock.assert_called_once()
+
     def test_parsed_args(self):
         with (
             patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
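
For reference, the user-facing behavior introduced by the `chat.py` change above, expressed as CLI invocations. This mirrors the three test cases exactly; the model name and address are the placeholder values used in the tests:

```sh
# Model only: works, and spawns a serve backend for that model underneath
transformers chat hf-internal-testing/tiny-random-gpt2

# Address only: now raises ValueError, since the model to chat with must be named explicitly
transformers chat localhost:8000

# Address plus model: connects to the already-running server and chats with the named model
transformers chat localhost:8000 --model_name_or_path=hf-internal-testing/tiny-random-gpt2
```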
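
And a minimal sketch of what the new `generation/utils.py` comment recommends: switching `generate()` from its default beam search parameterization to a canonical one via `early_stopping="never"` and `length_penalty=0.0`. The checkpoint below is just the tiny test model reused from the tests above; outputs will differ by model:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-gpt2"  # tiny checkpoint, reused from the tests above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello world", return_tensors="pt")

# Default parameterization: early_stopping=False (a heuristic that may cut generation
# short) and length_penalty=1.0 (positive, i.e. favors longer sequences despite the name).
default_output = model.generate(**inputs, num_beams=4, max_new_tokens=16)

# Canonical beam search, per the comment added in generation/utils.py above.
canonical_output = model.generate(
    **inputs,
    num_beams=4,
    max_new_tokens=16,
    early_stopping="never",
    length_penalty=0.0,
)
print(tokenizer.decode(canonical_output[0], skip_special_tokens=True))
```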