Enable conversational pipeline for GPTSw3Tokenizer (#24648)
* feat: Add `_build_conversation_input_ids` to GPT-SW3 tokenizer, adjust line length
* feat: Merge in PR https://github.com/huggingface/transformers/pull/24504. This allows the GPT-SW3 models (and other GPT-2 based models) to be 4-bit quantised using `load_in_4bit` with `bitsandbytes`.
* fix: F-string
* fix: F-string
* fix: Remove EOS token from all responses
* fix: Remove redundant newlines
* feat: Add `load_in_4bit` to `Pipeline`
* fix: Separate turns with `\n<s>\n` rather than `<s>`
* fix: Add missing newline in prompt
* tests: Add unit tests for the new `_build_conversation_input_ids` method
* style: Automatic style correction
* tests: Compare encodings rather than decodings
* fix: Remove `load_in_4bit` from pipeline arguments
* docs: Add description and references of the GPT-SW3 chat format
* style: Line breaks
* Apply suggestions from code review
  Fix Conversation type hints
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix: Import TYPE_CHECKING
* style: Run automatic fixes
* tests: Remove `_build_conversation_input_ids` unit tests
* tests: Remove import of `Conversation` in GPT-SW3 unit test
* style: Revert formatting
* style: Move TYPE_CHECKING line after all imports
* style: Imports order
* fix: Change prompt to ensure that `sp_model.encode` and `encode` yield the same result
* docs: Add TODO comment related to the addition of whitespace during decoding
* style: Automatic style checks
* fix: Remove final whitespace in prompt, as prefix whitespace is used by sentencepiece

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent f614b6e393
commit abaca9f943
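The merged PR referenced in the commit message (https://github.com/huggingface/transformers/pull/24504) enables 4-bit loading for GPT-SW3 and other GPT-2 based checkpoints. A minimal sketch of what that unlocks, assuming `bitsandbytes` is installed and a CUDA device is available; the model id comes from the references in the diff below:

```
# Hedged sketch: load a GPT-SW3 checkpoint in 4 bits via bitsandbytes.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "AI-Sweden-Models/gpt-sw3-126m-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",   # let accelerate place the layers
    load_in_4bit=True,   # requires bitsandbytes
)
```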
@@ -1,20 +1,23 @@
 """The tokenizer used by the GPT-SW3 models."""
 import os
 import re
 import unicodedata
-
-from ... import is_torch_available
+from shutil import copyfile
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import sentencepiece as spm
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...utils import is_torch_available, logging
 
 
 if is_torch_available():
     import torch
 
 
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import sentencepiece as spm
-
-from ...tokenization_utils import PreTrainedTokenizer
-from ...utils import logging
+if TYPE_CHECKING:
+    from transformers.pipelines.conversational import Conversation
 
 
 logger = logging.get_logger(__name__)
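The `if TYPE_CHECKING:` guard added above imports `Conversation` for static type checkers only; at runtime the import never executes, which avoids a circular import between the tokenizer module and `transformers.pipelines.conversational`. The method added further down therefore annotates its parameter as the string `"Conversation"`. A small self-contained sketch of the pattern (the function name is illustrative, not from the diff):

```
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # seen by mypy/pyright, skipped at runtime
    from transformers.pipelines.conversational import Conversation

def build(conversation: "Conversation") -> List[int]:
    # the quoted annotation is a forward reference, resolved lazily,
    # so no runtime import of Conversation is required
    ...
```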
@@ -230,8 +233,10 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
         for token in tokens:
             # make sure that special tokens are not decoded using sentencepiece model
             if token in self.all_special_tokens:
+                # TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
                 if not prev_is_special:
                     out_string += " "
+
                 out_string += self.sp_model.decode(current_sub_tokens) + token
                 prev_is_special = True
                 current_sub_tokens = []
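The TODO added in this hunk flags a round-trip asymmetry: when a special token follows ordinary text, a space is emitted before the flushed sentencepiece segment, so `decode(encode(doc))` can contain whitespace that `doc` never had. A hedged walk-through with hypothetical pieces (the real ones depend on the GPT-SW3 vocabulary):

```
# doc = "Hej<s>då" tokenizes, hypothetically, into ["▁Hej", "<s>", "▁då"].
# Tracing the loop above:
#   "▁Hej" -> buffered as an ordinary sub-token
#   "<s>"  -> special token; prev_is_special is False, so " " is appended,
#             then the flushed buffer plus the token: out_string == " Hej<s>"
#   "▁då"  -> buffered, flushed after the loop:       out_string == " Hej<s>då"
# Hence decode(encode("Hej<s>då")) == " Hej<s>då": one space the input never had.
```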
@@ -312,3 +317,32 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
         """
 
         return self.sp_model.decode(token_ids)
+
+    def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
+        """Builds the input ids for a conversation.
+
+        This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
+        The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
+
+        ```
+        <eos><bos>User: Jag tycker träd är fina<bos>Bot: Kul att du tycker det!<bos>...
+        ```
+
+        Args:
+            conversation (`Conversation`):
+                Conversation to build input ids for.
+
+        Returns:
+            `List[int]`:
+                Input ids for the conversation.
+
+        References:
+            - [1] https://doi.org/10.48550/arXiv.2305.12987
+            - [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
+            - [3] https://github.com/openai/openai-python/blob/main/chatml.md
+        """
+        all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
+        prompt = (
+            f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
+        )
+        return self.encode(text=prompt)
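For reference, here is the prompt the new method assembles, reproduced standalone, followed by the pipeline round trip it enables. The special-token strings (`<|endoftext|>` for EOS, `<s>` for BOS) are assumptions taken from the GPT-SW3 model card rather than from this diff:

```
# 1) The chat prompt, built exactly as _build_conversation_input_ids does.
eos_token, bos_token = "<|endoftext|>", "<s>"
turns = [
    (True, "Jag tycker träd är fina"),   # "I think trees are beautiful"
    (False, "Kul att du tycker det!"),   # "Glad you think so!"
]
all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in turns]
prompt = eos_token + bos_token + bos_token.join(all_responses) + f"{bos_token}Bot:"
print(prompt)
# <|endoftext|><s>User: Jag tycker träd är fina<s>Bot: Kul att du tycker det!<s>Bot:

# 2) End to end through the conversational pipeline, which calls the new method.
from transformers import Conversation, pipeline

chat = pipeline("conversational", model="AI-Sweden-Models/gpt-sw3-126m-instruct")
conversation = chat(Conversation("Träd är fina"))   # "Trees are beautiful"
print(conversation.generated_responses[-1])
```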