diff --git a/src/transformers/commands/chat.py b/src/transformers/commands/chat.py
index 91979590046..8f6f49f26bc 100644
--- a/src/transformers/commands/chat.py
+++ b/src/transformers/commands/chat.py
@@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand):
             )
 
             args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
+
+            if args.model_name_or_path is None:
+                raise ValueError(
+                    "When connecting to a server, please specify a model name with the --model_name_or_path flag."
+                )
         else:
             self.spawn_backend = True
             args.model_name_or_path = args.model_name_or_path_or_address
diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py
index 9b886b27210..f8b4131a463 100644
--- a/src/transformers/commands/serving.py
+++ b/src/transformers/commands/serving.py
@@ -623,7 +623,7 @@ class ServeCommand(BaseTransformersCLICommand):
 
         self.loaded_model = model_id_and_revision
 
-        print("Loaded model", model_id_and_revision)
+        logger.warning("Loaded model %s", model_id_and_revision)
 
         return model, tokenizer
 
diff --git a/tests/commands/test_chat.py b/tests/commands/test_chat.py
index 6ba3413fafa..e07df4a3938 100644
--- a/tests/commands/test_chat.py
+++ b/tests/commands/test_chat.py
@@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase):
         self.assertIn("chat interface", cs.out.lower())
 
     @patch.object(ChatCommand, "run")
-    def test_cli_dispatch(self, run_mock):
+    def test_cli_dispatch_model(self, run_mock):
+        """
+        Running `transformers chat` with just a model should work, spawning a serve backend underneath.
+        """
         args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
         with patch("sys.argv", args):
            cli.main()
         run_mock.assert_called_once()
 
+    def test_cli_dispatch_url(self):
+        """
+        Running `transformers chat` with just a URL should fail, since a model must additionally be specified.
+        """
+        args = ["transformers", "chat", "localhost:8000"]
+        with self.assertRaises(ValueError):
+            with patch("sys.argv", args):
+                cli.main()
+
+    @patch.object(ChatCommand, "run")
+    def test_cli_dispatch_url_and_model(self, run_mock):
+        """
+        Running `transformers chat` with both a URL and a model should work.
+        """
+        args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"]
+        with patch("sys.argv", args):
+            cli.main()
+        run_mock.assert_called_once()
+
     def test_parsed_args(self):
         with (
             patch.object(ChatCommand, "__init__", return_value=None) as init_mock,