mirror of https://github.com/huggingface/transformers.git (synced 2025-07-24 23:08:57 +06:00)
[serve] Model name or path should be required (#39178)
* Model name or path should be required
* Fix + add tests
* Change print to log so it doesn't display in transformers chat
This commit is contained in:
parent 2d561713f8
commit 548794b886
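For reference, these are the three invocations exercised by the new tests below (the tiny model id and port 8000 are only test fixtures):

    transformers chat hf-internal-testing/tiny-random-gpt2
        model only: works, spawning a serve backend underneath
    transformers chat localhost:8000 --model_name_or_path=hf-internal-testing/tiny-random-gpt2
        URL plus model: works, connecting to the running server
    transformers chat localhost:8000
        URL only: now raises ValueError asking for --model_name_or_path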
@@ -333,6 +333,11 @@ class ChatCommand(BaseTransformersCLICommand):
             )
 
             args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
+
+            if args.model_name_or_path is None:
+                raise ValueError(
+                    "When connecting to a server, please specify a model name with the --model_name_or_path flag."
+                )
         else:
             self.spawn_backend = True
             args.model_name_or_path = args.model_name_or_path_or_address
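The hunk only shows a fragment of the address-vs-model dispatch, so here is a minimal self-contained sketch of that control flow as read from the diff. The colon-based address detection and the resolve_chat_target helper are assumptions for illustration, and spawn_backend is stored on the args namespace here for simplicity (in the real ChatCommand it is set on the command instance); this is not the actual implementation.

# Hypothetical standalone sketch of the dispatch logic around the added check.
from argparse import Namespace


def resolve_chat_target(args: Namespace) -> Namespace:
    # Assumed heuristic: a value like "localhost:8000" is the address of a running server.
    is_address = ":" in args.model_name_or_path_or_address and "/" not in args.model_name_or_path_or_address
    if is_address:
        args.spawn_backend = False  # connect to an existing server instead of spawning one
        args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
        # New in this commit: a server address alone is not enough, a model name is required.
        if args.model_name_or_path is None:
            raise ValueError(
                "When connecting to a server, please specify a model name with the --model_name_or_path flag."
            )
    else:
        # The positional argument is a model id: spawn a local serve backend for it.
        args.spawn_backend = True
        args.model_name_or_path = args.model_name_or_path_or_address
    return args


# Mirrors the test cases: the first call succeeds, the second raises ValueError.
resolve_chat_target(
    Namespace(
        model_name_or_path_or_address="localhost:8000",
        model_name_or_path="hf-internal-testing/tiny-random-gpt2",
    )
)
try:
    resolve_chat_target(Namespace(model_name_or_path_or_address="localhost:8000", model_name_or_path=None))
except ValueError as err:
    print(err)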
@@ -623,7 +623,7 @@ class ServeCommand(BaseTransformersCLICommand):
 
         self.loaded_model = model_id_and_revision
 
-        print("Loaded model", model_id_and_revision)
+        logger.warning(f"Loaded model {model_id_and_revision}")
         return model, tokenizer
 
 
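The replaced line assumes a module-level logger; in transformers that logger comes from the library's own logging utilities, but the effect can be sketched with the standard library alone: the message goes through the logging system, where it can be routed or silenced, rather than being printed straight into the chat client's output (the motivation stated in the commit message).

# Minimal sketch, standard library only; the real module obtains its logger differently.
import logging

logger = logging.getLogger(__name__)


def announce_loaded_model(model_id_and_revision: str) -> None:
    # Hypothetical stand-in for the serve-side method touched above; only the
    # logging call mirrors the diff.
    logger.warning(f"Loaded model {model_id_and_revision}")


announce_loaded_model("hf-internal-testing/tiny-random-gpt2")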
@@ -29,12 +29,34 @@ class ChatCLITest(unittest.TestCase):
             self.assertIn("chat interface", cs.out.lower())
 
     @patch.object(ChatCommand, "run")
-    def test_cli_dispatch(self, run_mock):
+    def test_cli_dispatch_model(self, run_mock):
+        """
+        Running transformers chat with just a model should work & spawn a serve underneath
+        """
         args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
         with patch("sys.argv", args):
             cli.main()
         run_mock.assert_called_once()
 
+    def test_cli_dispatch_url(self):
+        """
+        Running transformers chat with just a URL should not work as a model should additionally be specified
+        """
+        args = ["transformers", "chat", "localhost:8000"]
+        with self.assertRaises(ValueError):
+            with patch("sys.argv", args):
+                cli.main()
+
+    @patch.object(ChatCommand, "run")
+    def test_cli_dispatch_url_and_model(self, run_mock):
+        """
+        Running transformers chat with a URL and a model should work
+        """
+        args = ["transformers", "chat", "localhost:8000", "--model_name_or_path=hf-internal-testing/tiny-random-gpt2"]
+        with patch("sys.argv", args):
+            cli.main()
+        run_mock.assert_called_once()
+
     def test_parsed_args(self):
         with (
             patch.object(ChatCommand, "__init__", return_value=None) as init_mock,