mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Enable conversational
pipeline for GPTSw3Tokenizer
(#24648)
* feat: Add `_build_conversation_input_ids` to GPT-SW3 tokenizer, adjust line length * feat: Merge in PR https://github.com/huggingface/transformers/pull/24504. This allows the GPT-SW3 models (and other GPT-2 based models) to be 4-bit quantised using `load_in_4bit` with `bitsandbytes`. * fix: F-string * fix: F-string * fix: Remove EOS token from all responses * fix: Remove redundant newlines * feat: Add `load_in_4bit` to `Pipeline` * fix: Separate turns with `\n<s>\n` rather than `<s>` * fix: Add missing newline in prompt * tests: Add unit tests for the new `_build_conversation_input_ids` method * style: Automatic style correction * tests: Compare encodings rather than decodings * fix: Remove `load_in_4bit` from pipeline arguments * docs: Add description and references of the GPT-SW3 chat format * style: Line breaks * Apply suggestions from code review Fix Conversation type hints Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * fix: Import TYPE_CHECKING * style: Run automatic fixes * tests: Remove `_build_conversation_input_ids` unit tests * tests: Remove import of `Conversation` in GPT-SW3 unit test * style: Revert formatting * style: Move TYPE_CHECKING line after all imports * style: Imports order * fix: Change prompt to ensure that `sp_model.encode` and `encode` yields same result * docs: Add TODO comment related to the addition of whitespace during decoding * style: Automatic style checks * fix: Remove final whitespace in prompt, as prefix whitespace is used by sentencepiece --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
f614b6e393
commit
abaca9f943
@ -1,20 +1,23 @@
|
||||
"""The tokenizer used by the GPT-SW3 models."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
from shutil import copyfile
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from ... import is_torch_available
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from ...tokenization_utils import PreTrainedTokenizer
|
||||
from ...utils import logging
|
||||
if TYPE_CHECKING:
|
||||
from transformers.pipelines.conversational import Conversation
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -230,8 +233,10 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
# TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
|
||||
if not prev_is_special:
|
||||
out_string += " "
|
||||
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
@ -312,3 +317,32 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
|
||||
return self.sp_model.decode(token_ids)
|
||||
|
||||
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
|
||||
"""Builds the input ids for a conversation.
|
||||
|
||||
This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
|
||||
The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
|
||||
|
||||
```
|
||||
<eos><bos>User: Jag tycker träd är fina<bos>Bot: Kul att du tycker det!<bos>...
|
||||
```
|
||||
|
||||
Args:
|
||||
conversation (`Conversation`):
|
||||
Conversation to build input ids for.
|
||||
|
||||
Returns:
|
||||
`List[int]`:
|
||||
Input ids for the conversation.
|
||||
|
||||
References:
|
||||
- [1] https://doi.org/10.48550/arXiv.2305.12987
|
||||
- [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
|
||||
- [3] https://github.com/openai/openai-python/blob/main/chatml.md
|
||||
"""
|
||||
all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
|
||||
prompt = (
|
||||
f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
|
||||
)
|
||||
return self.encode(text=prompt)
|
||||
|
Loading…
Reference in New Issue
Block a user