diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index be047f9f5cf..a6dff34ebc5 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -19,6 +19,7 @@ import types from contextlib import contextmanager from datetime import datetime from functools import lru_cache +from types import NoneType from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints from packaging import version @@ -77,6 +78,7 @@ def _get_json_schema_type(param_type: str) -> Dict[str, str]: float: {"type": "number"}, str: {"type": "string"}, bool: {"type": "boolean"}, + NoneType: {"type": "null"}, Any: {}, } if is_vision_available(): diff --git a/tests/utils/test_chat_template_utils.py b/tests/utils/test_chat_template_utils.py index 1816ddd9512..730f64859d7 100644 --- a/tests/utils/test_chat_template_utils.py +++ b/tests/utils/test_chat_template_utils.py @@ -419,6 +419,31 @@ class JsonSchemaGeneratorTest(unittest.TestCase): self.assertEqual(schema["function"], expected_schema) + def test_return_none(self): + def fn(x: int) -> None: + """ + Test function + + Args: + x: The first input + """ + pass + + schema = get_json_schema(fn) + expected_schema = { + "name": "fn", + "description": "Test function", + "parameters": { + "type": "object", + "properties": { + "x": {"type": "integer", "description": "The first input"}, + }, + "required": ["x"], + }, + "return": {"type": "null"}, + } + self.assertEqual(schema["function"], expected_schema) + def test_everything_all_at_once(self): def fn( x: str, y: Optional[List[Union[str, int]]], z: Tuple[Union[str, int], str] = (42, "hello")