Enable conversational pipeline for GPTSw3Tokenizer (#24648)

* feat: Add `_build_conversation_input_ids` to GPT-SW3 tokenizer, adjust line length

* feat: Merge in PR https://github.com/huggingface/transformers/pull/24504.

This allows the GPT-SW3 models (and other GPT-2 based models) to be 4-bit quantised
using `load_in_4bit` with `bitsandbytes`.

* fix: F-string

* fix: F-string

* fix: Remove EOS token from all responses

* fix: Remove redundant newlines

* feat: Add `load_in_4bit` to `Pipeline`

* fix: Separate turns with `\n<s>\n` rather than `<s>`

* fix: Add missing newline in prompt

* tests: Add unit tests for the new `_build_conversation_input_ids` method

* style: Automatic style correction

* tests: Compare encodings rather than decodings

* fix: Remove `load_in_4bit` from pipeline arguments

* docs: Add description and references of the GPT-SW3 chat format

* style: Line breaks

* Apply suggestions from code review

Fix Conversation type hints

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix: Import TYPE_CHECKING

* style: Run automatic fixes

* tests: Remove `_build_conversation_input_ids` unit tests

* tests: Remove import of `Conversation` in GPT-SW3 unit test

* style: Revert formatting

* style: Move TYPE_CHECKING line after all imports

* style: Imports order

* fix: Change prompt to ensure that `sp_model.encode` and `encode` yields same result

* docs: Add TODO comment related to the addition of whitespace during decoding

* style: Automatic style checks

* fix: Remove final whitespace in prompt, as prefix whitespace is used by sentencepiece

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Dan Saattrup Nielsen 2023-07-07 20:52:21 +02:00 committed by GitHub
parent f614b6e393
commit abaca9f943
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,20 +1,23 @@
"""The tokenizer used by the GPT-SW3 models."""
import os
import re
import unicodedata
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from ... import is_torch_available
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import is_torch_available, logging
if is_torch_available():
import torch
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import sentencepiece as spm
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = logging.get_logger(__name__)
@ -230,8 +233,10 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
for token in tokens:
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
# TODO: Check if this is needed, as it ensures that decode(encode(doc)) != doc by adding extra whitespace in the decoded document
if not prev_is_special:
out_string += " "
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
@ -312,3 +317,32 @@ class GPTSw3Tokenizer(PreTrainedTokenizer):
"""
return self.sp_model.decode(token_ids)
def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
"""Builds the input ids for a conversation.
This is the format used in the original GPT-SW3 paper [1] and which is also mentioned in the model card [2].
The format is inspired by the ChatML format [3]. Concretely, the chat format is set up as follows:
```
<eos><bos>User: Jag tycker träd är fina<bos>Bot: Kul att du tycker det!<bos>...
```
Args:
conversation (`Conversation`):
Conversation to build input ids for.
Returns:
`List[int]`:
Input ids for the conversation.
References:
- [1] https://doi.org/10.48550/arXiv.2305.12987
- [2] https://huggingface.co/AI-Sweden-Models/gpt-sw3-126m-instruct
- [3] https://github.com/openai/openai-python/blob/main/chatml.md
"""
all_responses = [f"User: {text}" if is_user else f"Bot: {text}" for is_user, text in conversation.iter_texts()]
prompt = (
f"{self.eos_token}{self.bos_token}" + f"{self.bos_token}".join(all_responses) + f"{self.bos_token}Bot:"
)
return self.encode(text=prompt)