mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Merge branch 'main' of github.com:huggingface/transformers into clean-llamas
This commit is contained in:
commit
8c96926f60
2
.github/workflows/check_failed_tests.yml
vendored
2
.github/workflows/check_failed_tests.yml
vendored
@ -41,7 +41,7 @@ jobs:
|
||||
check_new_failures:
|
||||
name: " "
|
||||
runs-on:
|
||||
group: aws-g4dn-4xlarge-cache
|
||||
group: aws-g5-4xlarge-cache
|
||||
container:
|
||||
image: ${{ inputs.docker }}
|
||||
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
2
.github/workflows/doctest_job.yml
vendored
2
.github/workflows/doctest_job.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
||||
matrix:
|
||||
split_keys: ${{ fromJson(inputs.split_keys) }}
|
||||
runs-on:
|
||||
group: aws-g4dn-4xlarge-cache
|
||||
group: aws-g5-4xlarge-cache
|
||||
container:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
2
.github/workflows/doctests.yml
vendored
2
.github/workflows/doctests.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
setup:
|
||||
name: Setup
|
||||
runs-on:
|
||||
group: aws-g4dn-4xlarge-cache
|
||||
group: aws-g5-4xlarge-cache
|
||||
container:
|
||||
image: huggingface/transformers-all-latest-gpu
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
|
||||
|
4
.github/workflows/model_jobs.yml
vendored
4
.github/workflows/model_jobs.yml
vendored
@ -107,9 +107,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ inputs.machine_type }}"
|
||||
|
||||
if [ "${{ inputs.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ inputs.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ inputs.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ inputs.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ inputs.machine_type }}
|
||||
|
12
.github/workflows/self-comment-ci.yml
vendored
12
.github/workflows/self-comment-ci.yml
vendored
@ -185,7 +185,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.get-tests.outputs.models) }}
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -239,9 +239,9 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -292,7 +292,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.get-tests.outputs.quantizations) }}
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -338,9 +338,9 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
|
26
.github/workflows/self-push.yml
vendored
26
.github/workflows/self-push.yml
vendored
@ -31,7 +31,7 @@ jobs:
|
||||
name: Setup
|
||||
strategy:
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-2xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -131,7 +131,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.setup.outputs.matrix) }}
|
||||
machine_type: [aws-g4dn-2xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -169,9 +169,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -244,7 +244,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.setup.outputs.matrix) }}
|
||||
machine_type: [aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -282,9 +282,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -357,7 +357,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-2xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -395,9 +395,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -467,7 +467,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -505,9 +505,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-2xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
|
28
.github/workflows/self-scheduled.yml
vendored
28
.github/workflows/self-scheduled.yml
vendored
@ -50,7 +50,7 @@ jobs:
|
||||
name: Setup
|
||||
strategy:
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -128,7 +128,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
slice_id: [0, 1]
|
||||
uses: ./.github/workflows/model_jobs.yml
|
||||
with:
|
||||
@ -145,7 +145,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -179,9 +179,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -213,7 +213,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-4xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -247,9 +247,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -282,7 +282,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -344,9 +344,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
@ -381,7 +381,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
folders: ${{ fromJson(needs.setup.outputs.quantization_matrix) }}
|
||||
machine_type: [aws-g4dn-4xlarge-cache, aws-g4dn-12xlarge-cache]
|
||||
machine_type: [aws-g5-4xlarge-cache, aws-g5-12xlarge-cache]
|
||||
runs-on:
|
||||
group: '${{ matrix.machine_type }}'
|
||||
container:
|
||||
@ -424,9 +424,9 @@ jobs:
|
||||
run: |
|
||||
echo "${{ matrix.machine_type }}"
|
||||
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g4dn-4xlarge-cache" ]; then
|
||||
if [ "${{ matrix.machine_type }}" = "aws-g5-4xlarge-cache" ]; then
|
||||
machine_type=single-gpu
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g4dn-12xlarge-cache" ]; then
|
||||
elif [ "${{ matrix.machine_type }}" = "aws-g5-12xlarge-cache" ]; then
|
||||
machine_type=multi-gpu
|
||||
else
|
||||
machine_type=${{ matrix.machine_type }}
|
||||
|
@ -288,7 +288,7 @@ Keywords: Music understanding, Music generation
|
||||
|
||||
## [dalle-flow](https://github.com/jina-ai/dalle-flow)
|
||||
|
||||
DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. Itt leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
|
||||
DALL·E Flow is an interactive workflow for generating high-definition images from a text prompt. It leverages DALL·E-Mega, GLID-3 XL, and Stable Diffusion to generate image candidates, and then calls CLIP-as-service to rank the candidates w.r.t. the prompt.
|
||||
The preferred candidate is fed to GLID-3 XL for diffusion, which often enriches the texture and background. Finally, the candidate is upscaled to 1024x1024 via SwinIR.
|
||||
|
||||
Keywords: High-definition image generation, Stable Diffusion, DALL-E Mega, GLID-3 XL, CLIP, SwinIR
|
||||
@ -526,7 +526,7 @@ Keywords: Model deployment, CLoud, Mobile, Edge
|
||||
|
||||
## [underthesea](https://github.com/undertheseanlp/underthesea)
|
||||
|
||||
[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provides extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.
|
||||
[underthesea](https://github.com/undertheseanlp/underthesea) is a Vietnamese NLP toolkit. Underthesea is a suite of open source Python modules data sets and tutorials supporting research and development in Vietnamese Natural Language Processing. We provide extremely easy API to quickly apply pretrained NLP models to your Vietnamese text, such as word segmentation, part-of-speech tagging (PoS), named entity recognition (NER), text classification and dependency parsing.
|
||||
|
||||
Keywords: Vietnamese, NLP
|
||||
|
||||
|
@ -56,7 +56,7 @@ Create a [`ImageTextToTextPipeline`] and pass the chat to it. For large models,
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda", torch_dtype=torch.float16)
|
||||
pipeline = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device_map="auto", torch_dtype=torch.float16)
|
||||
pipeline(text=messages, max_new_tokens=50, return_full_text=False)
|
||||
[{'input_text': [{'role': 'system',
|
||||
'content': [{'type': 'text',
|
||||
@ -175,7 +175,7 @@ processed_chat = processor.apply_chat_template(
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
video_fps=32,
|
||||
video_fps=16,
|
||||
video_load_backend="decord",
|
||||
)
|
||||
print(processed_chat.keys())
|
||||
|
@ -27,6 +27,9 @@ This guide shows you how to quickly start chatting with Transformers from the co
|
||||
|
||||
## transformers CLI
|
||||
|
||||
|
||||
### Interactive chat session
|
||||
|
||||
After you've [installed Transformers](./installation.md), chat with a model directly from the command line as shown below. It launches an interactive session with a model, with a few base commands listed at the start of the session.
|
||||
|
||||
```bash
|
||||
@ -51,6 +54,68 @@ transformers chat -h
|
||||
|
||||
The chat is implemented on top of the [AutoClass](./model_doc/auto), using tooling from [text generation](./llm_tutorial) and [chat](./chat_templating).
|
||||
|
||||
|
||||
### Serving a model and using MCP tools
|
||||
|
||||
> [!WARNING]
|
||||
> This section is experimental and subject to changes in future versions
|
||||
|
||||
Powering the `chat` interface, we have a server that takes user messages and returns completions. The server has a chat completion API compatible with the OpenAI SDK, so you can also quickly experiment with `transformers` models on existing aplications. To launch a server separately, use the `transformers serve` CLI:
|
||||
|
||||
```bash
|
||||
transformers serve Menlo/Jan-nano
|
||||
```
|
||||
|
||||
Under the hood, the `chat` CLI launches and uses `transformers serve`. This server is also an MCP client, which can receive information available MCP servers (i.e. tools), massage their information into the model prompt, and prepare calls to these tools when the model commands to do so. Naturally, this requires a model that is trained to use tools.
|
||||
|
||||
At the moment, MCP tool usage in `transformers` has the following constraints:
|
||||
- `chat` can't handle tools, but the [`tiny-agents`](https://huggingface.co/blog/python-tiny-agents) CLI can;
|
||||
- Only the `qwen` family of models is supported.
|
||||
|
||||
The first step to use MCP tools is to let the model know which tools are available. As an example, let's consider a `tiny-agents` configuration file with a reference to an [image generation MCP server](https://evalstate-flux1-schnell.hf.space/).
|
||||
|
||||
> [!TIP]
|
||||
> Many Hugging Face Spaces can be used as MCP servers. You can find all compatible Spaces [here](https://huggingface.co/spaces?filter=mcp-server).
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "http://localhost:8000",
|
||||
"provider": "local",
|
||||
"servers": [
|
||||
{
|
||||
"type": "sse",
|
||||
"config": {
|
||||
"url": "https://evalstate-flux1-schnell.hf.space/gradio_api/mcp/sse"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
You can then launch your `tiny-agents` chat interface with the following command.
|
||||
|
||||
```bash
|
||||
tiny-agents run path/to/your/config.json
|
||||
```
|
||||
|
||||
If you have a server (from `transformers serve`) running in the background, you're ready to use MCP tools from a local model! For instance, here's the example of a chat session:
|
||||
|
||||
```bash
|
||||
Agent loaded with 1 tools:
|
||||
• flux1_schnell_infer
|
||||
» Generate an image of a cat on the moon
|
||||
<Tool req_0_tool_call>flux1_schnell_infer {"prompt": "a cat on the moon", "seed": 42, "randomize_seed": true, "width": 1024, "height": 1024, "num_inference_steps": 4}
|
||||
|
||||
Tool req_0_tool_call
|
||||
[Binary Content: Image image/webp, 57732 bytes]
|
||||
The task is complete and the content accessible to the User
|
||||
Image URL: https://evalstate-flux1-schnell.hf.space/gradio_api/file=/tmp/gradio/3dbddc0e53b5a865ed56a4e3dbdd30f3f61cf3b8aabf1b456f43e5241bd968b8/image.webp
|
||||
380576952
|
||||
|
||||
I have generated an image of a cat on the moon using the Flux 1 Schnell Image Generator. The image is 1024x1024 pixels and was created with 4 inference steps. Let me know if you would like to make any changes or need further assistance!
|
||||
```
|
||||
|
||||
|
||||
## TextGenerationPipeline
|
||||
|
||||
[`TextGenerationPipeline`] is a high-level text generation class with a "chat mode". Chat mode is enabled when a conversational model is detected and the chat prompt is [properly formatted](./llm_tutorial#wrong-prompt-format).
|
||||
|
@ -26,6 +26,7 @@ Pass the audio signal, typically stored in `array`, to the feature extractor and
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
|
||||
dataset = load_dataset("PolyAI/minds14", name="en-US", split="train")
|
||||
processed_sample = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=16000)
|
||||
processed_sample
|
||||
{'input_values': [array([ 9.4472744e-05, 3.0777880e-03, -2.8888427e-03, ...,
|
||||
|
@ -14,59 +14,123 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# BigBirdPegasus
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Overview
|
||||
# BigBirdPegasus
|
||||
|
||||
The BigBird model was proposed in [Big Bird: Transformers for Longer Sequences](https://huggingface.co/papers/2007.14062) by
|
||||
Zaheer, Manzil and Guruganesh, Guru and Dubey, Kumar Avinava and Ainslie, Joshua and Alberti, Chris and Ontanon,
|
||||
Santiago and Pham, Philip and Ravula, Anirudh and Wang, Qifan and Yang, Li and others. BigBird, is a sparse-attention
|
||||
based transformer which extends Transformer based models, such as BERT to much longer sequences. In addition to sparse
|
||||
attention, BigBird also applies global attention as well as random attention to the input sequence. Theoretically, it
|
||||
has been shown that applying sparse, global, and random attention approximates full attention, while being
|
||||
computationally much more efficient for longer sequences. As a consequence of the capability to handle longer context,
|
||||
BigBird has shown improved performance on various long document NLP tasks, such as question answering and
|
||||
summarization, compared to BERT or RoBERTa.
|
||||
[BigBirdPegasus](https://huggingface.co/papers/2007.14062) is an encoder-decoder (sequence-to-sequence) transformer model for long-input summarization. It extends the [BigBird](./big_bird) architecture with an additional pretraining objective borrowed from [Pegasus](./pegasus) called gap sequence generation (GSG). Whole sentences are masked and the model has to fill in the gaps in the document. BigBirdPegasus's ability to keep track of long contexts makes it effective at summarizing lengthy inputs, surpassing the performance of base Pegasus models.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
You can find all the original BigBirdPegasus checkpoints under the [Google](https://huggingface.co/google/models?search=bigbird-pegasus) organization.
|
||||
|
||||
*Transformers-based models, such as BERT, have been one of the most successful deep learning models for NLP.
|
||||
Unfortunately, one of their core limitations is the quadratic dependency (mainly in terms of memory) on the sequence
|
||||
length due to their full attention mechanism. To remedy this, we propose, BigBird, a sparse attention mechanism that
|
||||
reduces this quadratic dependency to linear. We show that BigBird is a universal approximator of sequence functions and
|
||||
is Turing complete, thereby preserving these properties of the quadratic, full attention model. Along the way, our
|
||||
theoretical analysis reveals some of the benefits of having O(1) global tokens (such as CLS), that attend to the entire
|
||||
sequence as part of the sparse attention mechanism. The proposed sparse attention can handle sequences of length up to
|
||||
8x of what was previously possible using similar hardware. As a consequence of the capability to handle longer context,
|
||||
BigBird drastically improves performance on various NLP tasks such as question answering and summarization. We also
|
||||
propose novel applications to genomics data.*
|
||||
> [!TIP]
|
||||
> This model was contributed by [vasudevgupta](https://huggingface.co/vasudevgupta).
|
||||
>
|
||||
> Click on the BigBirdPegasus models in the right sidebar for more examples of how to apply BigBirdPegasus to different language tasks.
|
||||
|
||||
The original code can be found [here](https://github.com/google-research/bigbird).
|
||||
The example below demonstrates how to summarize text with [`Pipeline`], [`AutoModel`], and from the command line.
|
||||
|
||||
## Usage tips
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="Pipeline">
|
||||
|
||||
- For an in-detail explanation on how BigBird's attention works, see [this blog post](https://huggingface.co/blog/big-bird).
|
||||
- BigBird comes with 2 implementations: **original_full** & **block_sparse**. For the sequence length < 1024, using
|
||||
**original_full** is advised as there is no benefit in using **block_sparse** attention.
|
||||
- The code currently uses window size of 3 blocks and 2 global blocks.
|
||||
- Sequence length must be divisible by block size.
|
||||
- Current implementation supports only **ITC**.
|
||||
- Current implementation doesn't support **num_random_blocks = 0**.
|
||||
- BigBirdPegasus uses the [PegasusTokenizer](https://github.com/huggingface/transformers/blob/main/src/transformers/models/pegasus/tokenization_pegasus.py).
|
||||
- BigBird is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
|
||||
the left.
|
||||
```py
|
||||
import torch
|
||||
from transformers import pipeline
|
||||
|
||||
pipeline = pipeline(
|
||||
task="summarization",
|
||||
model="google/bigbird-pegasus-large-arxiv",
|
||||
torch_dtype=torch.float32,
|
||||
device=0
|
||||
)
|
||||
pipeline("""Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
|
||||
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
|
||||
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
|
||||
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle.""")
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="AutoModel">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/bigbird-pegasus-large-arxiv"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google/bigbird-pegasus-large-arxiv",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
|
||||
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
|
||||
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
|
||||
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="transformers-cli">
|
||||
|
||||
```bash
|
||||
echo -e "Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet. Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts." | transformers-cli run --task summarization --model google/bigbird-pegasus-large-arxiv --device 0
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.
|
||||
|
||||
The example below uses [bitsandbytes](../quantization/bitsandbytes) to only quantize the weights to int4.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_type="nf4"
|
||||
)
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
"google/bigbird-pegasus-large-arxiv",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"google/bigbird-pegasus-large-arxiv"
|
||||
)
|
||||
|
||||
input_text = """Plants are among the most remarkable and essential life forms on Earth, possessing a unique ability to produce their own food through a process known as photosynthesis. This complex biochemical process is fundamental not only to plant life but to virtually all life on the planet.
|
||||
Through photosynthesis, plants capture energy from sunlight using a green pigment called chlorophyll, which is located in specialized cell structures called chloroplasts. In the presence of light, plants absorb carbon dioxide from the atmosphere through small pores in their leaves called stomata, and take in water from the soil through their root systems.
|
||||
These ingredients are then transformed into glucose, a type of sugar that serves as a source of chemical energy, and oxygen, which is released as a byproduct into the atmosphere. The glucose produced during photosynthesis is not just used immediately; plants also store it as starch or convert it into other organic compounds like cellulose, which is essential for building their cellular structure.
|
||||
This energy reserve allows them to grow, develop leaves, produce flowers, bear fruit, and carry out various physiological processes throughout their lifecycle."""
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
output = model.generate(**input_ids, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- BigBirdPegasus also uses the [`PegasusTokenizer`].
|
||||
- Inputs should be padded on the right because BigBird uses absolute position embeddings.
|
||||
- BigBirdPegasus supports `original_full` and `block_sparse` attention. If the input sequence length is less than 1024, it is recommended to use `original_full` since sparse patterns don't offer much benefit for smaller inputs.
|
||||
- The current implementation uses window size of 3 blocks and 2 global blocks, only supports the ITC-implementation, and doesn't support `num_random_blocks=0`.
|
||||
- The sequence length must be divisible by the block size.
|
||||
|
||||
## Resources
|
||||
|
||||
- [Text classification task guide](../tasks/sequence_classification)
|
||||
- [Question answering task guide](../tasks/question_answering)
|
||||
- [Causal language modeling task guide](../tasks/language_modeling)
|
||||
- [Translation task guide](../tasks/translation)
|
||||
- [Summarization task guide](../tasks/summarization)
|
||||
Read the [Understanding BigBird's Block Sparse Attention](https://huggingface.co/blog/big-bird) blog post for more details about how BigBird's attention works.
|
||||
|
||||
## BigBirdPegasusConfig
|
||||
|
||||
|
@ -32,8 +32,8 @@ this model, including [Alternating Updates][altup] (AltUp), [Learned Augmented R
|
||||
[MatFormer][matformer], Per-Layer Embeddings (PLE), activation sparsity, and KV cache sharing. The language model uses
|
||||
a similar attention pattern to [Gemma 3](./gemma3.md) with alternating 4 local sliding window self-attention layers for
|
||||
every global self-attention layer with a maximum context length of 32k tokens. Gemma 3n introduces
|
||||
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a
|
||||
[Universal Speech Model][usm] (USM) as the audio encoder.
|
||||
[MobileNet v5][mobilenetv5] as the vision encoder, using a default resolution of 768x768 pixels, and adds a newly
|
||||
trained audio encoder based on the [Universal Speech Model][usm] (USM) architecture.
|
||||
|
||||
The instruction-tuned variant was post-trained with knowledge distillation and reinforcement learning.
|
||||
|
||||
|
@ -15,9 +15,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Distributed inference
|
||||
|
||||
When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
|
||||
When a model doesn't fit on a single GPU, distributed inference with [tensor parallelism](./perf_train_gpu_many#tensor-parallelism) can help. Tensor parallelism shards a model onto multiple accelerators (CUDA GPU, Intel XPU, etc.) and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each accelerator can process a tensor slice.
|
||||
|
||||
However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple GPUs to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.
|
||||
However, tensor parallelism adds communication overhead and should be used on single machine setups with multiple accelerators to take advantage of fast intra-node communication. For multi-node training, it may be more efficient to use pipeline or data parallelism depending on your use case.
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism to learn more.
|
||||
@ -308,4 +308,4 @@ The most important part of DTensor is the `placement` attribute because it tells
|
||||
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
|
||||
```
|
||||
|
||||
- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
|
||||
- `Partial()` - Indicates a tensor is pending a reduction operation (not typically relevant for usage in Transformers).
|
||||
|
@ -47,7 +47,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
@ -49,6 +49,7 @@ Check the table below to see if your hardware is compatible.
|
||||
| Component | Compatibility |
|
||||
|----------|----------------|
|
||||
| CUDA Versions | ✅ cu118, cu126, cu128 |
|
||||
| XPU Versions | ✅ pytorch2.8 |
|
||||
| CPU | ✅ change `device_map="cpu"` (see examples below) |
|
||||
|
||||
|
||||
@ -278,6 +279,71 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Intel XPU
|
||||
<hfoptions id="examples-Intel-XPU">
|
||||
<hfoption id="int8-dynamic-and-weight-only">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig
|
||||
|
||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||
# or int8 weight only quantization
|
||||
# quant_config = Int8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="int4-weight-only">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
from torchao.dtypes import Int4XPULayout
|
||||
from torchao.quantization.quant_primitives import ZeroPointDomain
|
||||
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4XPULayout(), zero_point_domain=ZeroPointDomain.INT)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("xpu")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### CPU
|
||||
<hfoptions id="examples-CPU">
|
||||
<hfoption id="int8-dynamic-and-weight-only">
|
||||
@ -363,7 +429,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Manual Testing
|
||||
prompt = "Hey, are you conscious? Can you talk to me?"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(quantized_model.device.type)
|
||||
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
|
||||
output_text = tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
@ -434,7 +500,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(quantized_model.device.type)
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
@ -474,7 +540,7 @@ tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
|
||||
|
||||
## Loading quantized models
|
||||
|
||||
Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA.
|
||||
Loading a quantized model depends on the quantization scheme. For quantization schemes, like int8 and float8, you can quantize the model on any device and also load it on any device. The example below demonstrates quantizing a model on the CPU and then loading it on CUDA or XPU.
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
@ -491,7 +557,7 @@ quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
# save the quantized model
|
||||
output_dir = "llama-3.1-8b-torchao-int8-cuda"
|
||||
output_dir = "llama-3.1-8b-torchao-int8"
|
||||
quantized_model.save_pretrained(output_dir, safe_serialization=False)
|
||||
|
||||
# reload the quantized model
|
||||
@ -502,7 +568,7 @@ reloaded_model = AutoModelForCausalLM.from_pretrained(
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to(reloaded_model.device.type)
|
||||
|
||||
output = reloaded_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
2
setup.py
2
setup.py
@ -148,7 +148,7 @@ _deps = [
|
||||
"protobuf",
|
||||
"psutil",
|
||||
"pyyaml>=5.1",
|
||||
"pydantic",
|
||||
"pydantic>=2",
|
||||
"pytest>=7.2.0",
|
||||
"pytest-asyncio",
|
||||
"pytest-rerunfailures",
|
||||
|
@ -13,33 +13,30 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import copy
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
import warnings
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import yaml
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
from huggingface_hub import AsyncInferenceClient, ChatCompletionStreamOutput
|
||||
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizer,
|
||||
TextIteratorStreamer,
|
||||
logging,
|
||||
)
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.commands.serving import ServeArguments, ServeCommand
|
||||
from transformers.utils import is_rich_available, is_torch_available
|
||||
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
if platform.system() != "Windows":
|
||||
import pwd
|
||||
@ -52,8 +49,12 @@ if is_rich_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedModel
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GenerationConfig,
|
||||
)
|
||||
|
||||
ALLOWED_KEY_CHARS = set(string.ascii_letters + string.whitespace)
|
||||
ALLOWED_VALUE_CHARS = set(
|
||||
@ -107,19 +108,6 @@ If you're a new user, check this basic flag guide: https://huggingface.co/docs/t
|
||||
- **!exit**: closes the interface
|
||||
"""
|
||||
|
||||
# format: (optional CLI arg being deprecated, its current default, corresponding `generate` flag)
|
||||
_DEPRECATION_MAP = [
|
||||
("max_new_tokens", 256, "max_new_tokens"),
|
||||
("do_sample", True, "do_sample"),
|
||||
("num_beams", 1, "num_beams"),
|
||||
("temperature", 1.0, "temperature"),
|
||||
("top_k", 50, "top_k"),
|
||||
("top_p", 1.0, "top_p"),
|
||||
("repetition_penalty", 1.0, "repetition_penalty"),
|
||||
("eos_tokens", None, "eos_token_id"),
|
||||
("eos_token_ids", None, "eos_token_id"),
|
||||
]
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name: Optional[str] = None, user_name: Optional[str] = None):
|
||||
@ -133,21 +121,21 @@ class RichInterface:
|
||||
else:
|
||||
self.user_name = user_name
|
||||
|
||||
def stream_output(self, output_stream: TextIteratorStreamer) -> str:
|
||||
"""Stream output from a role, and return the generated text after it's done steaming."""
|
||||
# This method is originally from the FastChat CLI:
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
|
||||
# Create a Live context for updating the console output
|
||||
text = ""
|
||||
async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
|
||||
self._console.print(f"[bold blue]<{self.model_name}>:")
|
||||
with Live(console=self._console, refresh_per_second=4) as live:
|
||||
# Read lines from the stream
|
||||
for i, outputs in enumerate(output_stream):
|
||||
if not outputs or i == 0:
|
||||
text = ""
|
||||
async for token in await stream:
|
||||
outputs = token.choices[0].delta.content
|
||||
request_id = token.id
|
||||
|
||||
if not outputs:
|
||||
continue
|
||||
|
||||
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
|
||||
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
|
||||
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
|
||||
|
||||
text += outputs
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
@ -160,6 +148,7 @@ class RichInterface:
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
lines.append(line)
|
||||
@ -169,11 +158,15 @@ class RichInterface:
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
|
||||
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||
|
||||
# Update the Live console output
|
||||
live.update(markdown)
|
||||
live.update(markdown, refresh=True)
|
||||
|
||||
self._console.print()
|
||||
return text
|
||||
|
||||
return text, request_id
|
||||
|
||||
def input(self) -> str:
|
||||
"""Gets user input from the console."""
|
||||
@ -245,25 +238,6 @@ class ChatArguments:
|
||||
),
|
||||
},
|
||||
)
|
||||
# Deprecated CLI args start here
|
||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate."})
|
||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation."})
|
||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search."})
|
||||
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation."})
|
||||
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling."})
|
||||
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling."})
|
||||
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty."})
|
||||
eos_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "EOS tokens (text format) to stop the generation. If multiple they should be comma separated."
|
||||
},
|
||||
)
|
||||
eos_token_ids: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated."},
|
||||
)
|
||||
# Deprecated CLI args end here
|
||||
|
||||
# Model loading
|
||||
model_revision: str = field(
|
||||
@ -300,6 +274,10 @@ class ChatArguments:
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
|
||||
|
||||
# Serving settings
|
||||
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
|
||||
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
|
||||
|
||||
|
||||
def chat_command_factory(args: Namespace):
|
||||
"""
|
||||
@ -322,7 +300,10 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
group = chat_parser.add_argument_group("Positional arguments")
|
||||
group.add_argument(
|
||||
"model_name_or_path_positional", type=str, default=None, help="Name of the pre-trained model."
|
||||
"model_name_or_path_or_address",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the pre-trained model or address to connect to.",
|
||||
)
|
||||
group.add_argument(
|
||||
"generate_flags",
|
||||
@ -332,57 +313,45 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
|
||||
"and lists of integers, more advanced parameterization should be set through --generation-config. "
|
||||
"Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. "
|
||||
"If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
||||
"If you're a new user, check this basic flag guide: "
|
||||
"https://huggingface.co/docs/transformers/llm_tutorial#common-options"
|
||||
),
|
||||
nargs="*",
|
||||
)
|
||||
chat_parser.set_defaults(func=chat_command_factory)
|
||||
|
||||
def __init__(self, args):
|
||||
args = self._handle_deprecated_args(args)
|
||||
if args.model_name_or_path_or_address is not None:
|
||||
name = args.model_name_or_path_or_address
|
||||
if name.startswith("http") or name.startswith("https") or name.startswith("localhost"):
|
||||
self.spawn_backend = False
|
||||
|
||||
if args.host != "localhost" or args.port != 8000:
|
||||
raise ValueError(
|
||||
"Looks like you’ve set both a server address and a custom host/port. "
|
||||
"Please pick just one way to specify the server."
|
||||
)
|
||||
|
||||
args.host, args.port = args.model_name_or_path_or_address.rsplit(":", 1)
|
||||
else:
|
||||
self.spawn_backend = True
|
||||
args.model_name_or_path = args.model_name_or_path_or_address
|
||||
|
||||
if not is_rich_available() and (not is_torch_available() and self.spawn_backend):
|
||||
raise ImportError(
|
||||
"You need to install rich to use the chat interface. Additionally, you have not specified a remote "
|
||||
"endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)"
|
||||
)
|
||||
elif not is_rich_available():
|
||||
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
|
||||
elif not is_torch_available() and self.spawn_backend:
|
||||
raise ImportError(
|
||||
"You have not specified a remote endpoint and are therefore spawning a backend. Torch is required "
|
||||
"for this: (`pip install rich torch`)"
|
||||
)
|
||||
|
||||
self.args = args
|
||||
|
||||
def _handle_deprecated_args(self, args: ChatArguments) -> ChatArguments:
|
||||
"""
|
||||
Handles deprecated arguments and their deprecation cycle. To be removed after we fully migrated to the new
|
||||
args.
|
||||
"""
|
||||
has_warnings = False
|
||||
|
||||
# 1. Model as a positional argument
|
||||
args.model_name_or_path_positional = args.model_name_or_path_positional or args.model_name_or_path
|
||||
if args.model_name_or_path_positional is None:
|
||||
raise ValueError(
|
||||
"One of the following must be provided:"
|
||||
"\n- The positional argument containing the model repo, e.g. `transformers chat <model_repo>`"
|
||||
"\n- the optional --model_name_or_path argument, containing the model repo (deprecated)"
|
||||
)
|
||||
elif args.model_name_or_path is not None:
|
||||
has_warnings = True
|
||||
warnings.warn(
|
||||
"The --model_name_or_path argument is deprecated will be removed in v4.54.0. Use the positional "
|
||||
"argument instead, e.g. `transformers chat <model_repo>`.",
|
||||
FutureWarning,
|
||||
)
|
||||
# 2. Named generate option args
|
||||
for deprecated_arg, default_value, new_arg in _DEPRECATION_MAP:
|
||||
value = getattr(args, deprecated_arg)
|
||||
if value != default_value:
|
||||
has_warnings = True
|
||||
warnings.warn(
|
||||
f"The --{deprecated_arg} argument is deprecated will be removed in v4.54.0. There are two "
|
||||
"alternative solutions to specify this generation option: \n"
|
||||
"1. Pass `--generation-config <path_to_file/Hub repo>` to specify a generation config.\n"
|
||||
"2. Pass `generate` flags through positional arguments, e.g. `transformers chat <model_repo> "
|
||||
f"{new_arg}={value}`",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if has_warnings:
|
||||
print("\n(Press enter to continue)")
|
||||
input()
|
||||
return args
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Chat session methods
|
||||
@staticmethod
|
||||
@ -404,7 +373,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
if filename is None:
|
||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
filename = f"{args.model_name_or_path_positional}/chat_{time_str}.json"
|
||||
filename = f"{args.model_name_or_path_or_address}/chat_{time_str}.json"
|
||||
filename = os.path.join(folder, filename)
|
||||
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
@ -477,40 +446,23 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
return processed_generate_flags
|
||||
|
||||
def get_generation_parameterization(
|
||||
self, args: ChatArguments, tokenizer: AutoTokenizer, model: PreTrainedModel
|
||||
) -> tuple[GenerationConfig, dict]:
|
||||
def get_generation_parameterization(self, args: ChatArguments) -> tuple[GenerationConfig, dict]:
|
||||
"""
|
||||
Returns a GenerationConfig object holding the generation parameters for the CLI command.
|
||||
"""
|
||||
# No generation config arg provided -> use default generation config, apply CLI defaults
|
||||
if args.generation_config is None:
|
||||
# We start off from the checkpoint's generation config
|
||||
generation_config = copy.deepcopy(model.generation_config)
|
||||
# Apply deprecated CLI args on top of the default generation config
|
||||
pad_token_id, eos_token_ids = self.parse_eos_tokens(
|
||||
tokenizer, generation_config, args.eos_tokens, args.eos_token_ids
|
||||
)
|
||||
deprecated_kwargs = {
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"do_sample": args.do_sample,
|
||||
"num_beams": args.num_beams,
|
||||
"temperature": args.temperature,
|
||||
"top_k": args.top_k,
|
||||
"top_p": args.top_p,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
"pad_token_id": pad_token_id,
|
||||
"eos_token_id": eos_token_ids,
|
||||
}
|
||||
generation_config.update(**deprecated_kwargs)
|
||||
# generation config arg provided -> use it as the base parameterization
|
||||
else:
|
||||
# No generation config arg provided -> use base generation config, apply CLI defaults
|
||||
if args.generation_config is not None:
|
||||
if ".json" in args.generation_config: # is a local file
|
||||
dirname = os.path.dirname(args.generation_config)
|
||||
filename = os.path.basename(args.generation_config)
|
||||
generation_config = GenerationConfig.from_pretrained(dirname, filename)
|
||||
else:
|
||||
generation_config = GenerationConfig.from_pretrained(args.generation_config)
|
||||
else:
|
||||
# !!!!!!!!!
|
||||
# This is a chat session, so we have a few non-standard defaults
|
||||
# !!!!!!!!!
|
||||
generation_config = GenerationConfig(do_sample=True, max_new_tokens=256)
|
||||
|
||||
# Finally: parse and apply `generate_flags`
|
||||
parsed_generate_flags = self.parse_generate_flags(args.generate_flags)
|
||||
@ -664,7 +616,7 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
|
||||
elif user_input == "!status":
|
||||
interface.print_status(
|
||||
model_name=args.model_name_or_path_positional,
|
||||
model_name=args.model_name_or_path,
|
||||
generation_config=generation_config,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
@ -679,10 +631,32 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
# -----------------------------------------------------------------------------------------------------------------
|
||||
# Main logic
|
||||
def run(self):
|
||||
if not is_rich_available():
|
||||
raise ImportError("You need to install rich to use the chat interface. (`pip install rich`)")
|
||||
if not is_torch_available():
|
||||
raise ImportError("You need to install torch to use the chat interface. (`pip install torch`)")
|
||||
asyncio.run(self._inner_run())
|
||||
|
||||
async def _inner_run(self):
|
||||
if self.spawn_backend:
|
||||
serve_args = ServeArguments(
|
||||
device=self.args.device,
|
||||
torch_dtype=self.args.torch_dtype,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
attn_implementation=self.args.attn_implementation,
|
||||
load_in_8bit=self.args.load_in_8bit,
|
||||
load_in_4bit=self.args.load_in_4bit,
|
||||
bnb_4bit_quant_type=self.args.bnb_4bit_quant_type,
|
||||
use_bnb_nested_quant=self.args.use_bnb_nested_quant,
|
||||
host=self.args.host,
|
||||
port=self.args.port,
|
||||
log_level="error",
|
||||
)
|
||||
serve_command = ServeCommand(serve_args)
|
||||
|
||||
thread = Thread(target=serve_command.run)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
model = self.args.model_name_or_path + "@" + self.args.model_revision
|
||||
host = "http://localhost" if self.args.host == "localhost" else self.args.host
|
||||
client = AsyncInferenceClient(f"{host}:{self.args.port}")
|
||||
|
||||
args = self.args
|
||||
if args.examples_path is None:
|
||||
@ -696,19 +670,14 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
else:
|
||||
user = args.user
|
||||
|
||||
model, tokenizer = self.load_model_and_tokenizer(args)
|
||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||
generation_config, model_kwargs = self.get_generation_parameterization(args, tokenizer, model)
|
||||
generation_config, model_kwargs = self.get_generation_parameterization(args)
|
||||
|
||||
# if not verbose -> disable warnings, progress bars, etc in the chat interface
|
||||
if not args.verbose:
|
||||
logging.set_verbosity_error()
|
||||
disable_progress_bars()
|
||||
|
||||
interface = RichInterface(model_name=args.model_name_or_path_positional, user_name=user)
|
||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||
interface.clear()
|
||||
chat = self.clear_chat_history(args.system_prompt)
|
||||
|
||||
request_id = None
|
||||
|
||||
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
|
||||
interface.print_help(minimal=True)
|
||||
while True:
|
||||
@ -736,23 +705,29 @@ class ChatCommand(BaseTransformersCLICommand):
|
||||
else:
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
|
||||
inputs = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
stream = client.chat_completion(
|
||||
chat,
|
||||
stream=True,
|
||||
extra_body={
|
||||
"request_id": request_id,
|
||||
"generation_config": {**generation_config.to_dict()},
|
||||
"model": model,
|
||||
},
|
||||
)
|
||||
attention_mask = torch.ones_like(inputs)
|
||||
generation_kwargs = {
|
||||
"inputs": inputs,
|
||||
"attention_mask": attention_mask,
|
||||
"streamer": generation_streamer,
|
||||
"generation_config": generation_config,
|
||||
**model_kwargs,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
model_output = interface.stream_output(generation_streamer)
|
||||
thread.join()
|
||||
model_output, request_id = await interface.stream_output(stream)
|
||||
|
||||
chat.append({"role": "assistant", "content": model_output})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = ChatArguments()
|
||||
args.model_name_or_path_or_address = "meta-llama/Llama-3.2-3b-Instruct"
|
||||
args.model_name_or_path_or_address = "http://localhost:8000"
|
||||
chat = ChatCommand(args)
|
||||
chat.run()
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -11,33 +11,95 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Thread
|
||||
from typing import Any, Optional
|
||||
|
||||
from ..pipelines import Pipeline, get_supported_tasks, pipeline
|
||||
from ..utils import logging
|
||||
from huggingface_hub import (
|
||||
ChatCompletionStreamOutputChoice,
|
||||
ChatCompletionStreamOutputDelta,
|
||||
ChatCompletionStreamOutputDeltaToolCall,
|
||||
ChatCompletionStreamOutputFunction,
|
||||
ModelInfo,
|
||||
model_info,
|
||||
)
|
||||
|
||||
from transformers.utils.import_utils import is_fastapi_available, is_pydantic_available, is_uvicorn_available
|
||||
|
||||
from .. import PreTrainedTokenizerFast, TextIteratorStreamer
|
||||
from ..generation.continuous_batching import ContinuousBatchingManager, RequestStatus
|
||||
from ..utils import is_torch_available, logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
from fastapi import Body, FastAPI, HTTPException
|
||||
from fastapi.routing import APIRoute
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
|
||||
|
||||
if is_pydantic_available() and is_fastapi_available() and is_uvicorn_available():
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import JSONResponse
|
||||
from uvicorn import run
|
||||
|
||||
_serve_dependencies_installed = True
|
||||
except (ImportError, AttributeError):
|
||||
BaseModel = object
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def Body(*x, **y):
|
||||
pass
|
||||
class ChatCompletionInput(BaseModel):
|
||||
messages: list[Message]
|
||||
|
||||
_serve_dependencies_installed = False
|
||||
stream: Optional[bool] = False
|
||||
model: Optional[str] = None
|
||||
request_id: Optional[str] = None
|
||||
extra_body: Optional[dict] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[list[float]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
stop: Optional[list[str]] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
# Additional options supported by the HFH InferenceClient
|
||||
# that aren't yet supported here.
|
||||
|
||||
# logprobs: Optional[bool] = None
|
||||
tools: Any = None
|
||||
# n: Optional[int] = None
|
||||
# presence_penalty: Optional[float] = None
|
||||
# response_format: Optional[ChatCompletionInputGrammarType] = None
|
||||
# stream_options: Optional[ChatCompletionInputStreamOptions] = None
|
||||
# tool_choice: Optional[Union[ChatCompletionInputToolChoiceClass, "ChatCompletionInputToolChoiceEnum"]] = None
|
||||
# tool_prompt: Optional[str] = None
|
||||
# top_logprobs: Optional[int] = None
|
||||
|
||||
|
||||
logger = logging.get_logger("transformers/serving")
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Possible tokens that indicate the start/end of a tool call
|
||||
# TODO (joao, matt): streamline tool token detection logic
|
||||
_TOOL_CALL_TOKENS = {
|
||||
"qwen": {
|
||||
"start": "<tool_call>",
|
||||
"end": "</tool_call>",
|
||||
},
|
||||
}
|
||||
_MODELS_WITH_TOOL_SUPPORT = list(_TOOL_CALL_TOKENS.keys())
|
||||
|
||||
|
||||
def serve_command_factory(args: Namespace):
|
||||
@ -46,50 +108,114 @@ def serve_command_factory(args: Namespace):
|
||||
|
||||
Returns: ServeCommand
|
||||
"""
|
||||
nlp = pipeline(
|
||||
task=args.task,
|
||||
model=args.model if args.model else None,
|
||||
config=args.config,
|
||||
tokenizer=args.tokenizer,
|
||||
device=args.device,
|
||||
return ServeCommand(args)
|
||||
|
||||
|
||||
def create_generation_config_from_req(req: "ChatCompletionInput") -> "GenerationConfig":
|
||||
"""
|
||||
Creates a generation config from the parameters of the request. Note that we can pass a `GenerationConfig`
|
||||
(serialized into a `dict`) in `extra_body`, for full `generate` parameterization.
|
||||
|
||||
Args:
|
||||
req (`ChatCompletionInput`): The request which may optionally contain generation parameters.
|
||||
|
||||
Returns:
|
||||
The prepared `GenerationConfig` object.
|
||||
"""
|
||||
if req.extra_body is not None and "generation_config" in req.extra_body:
|
||||
for key in req.extra_body["generation_config"].keys():
|
||||
if key in ChatCompletionInput.base_field_names.keys():
|
||||
return {"error": "Duplicated key in the root request and in the passed generation config."}
|
||||
|
||||
if req.extra_body is not None and "generation_config" in req.extra_body:
|
||||
generation_config = GenerationConfig(**(req.extra_body["generation_config"]))
|
||||
else:
|
||||
generation_config = GenerationConfig()
|
||||
|
||||
if req.frequency_penalty is not None:
|
||||
generation_config.repetition_penalty = req.frequency_penalty
|
||||
if req.logit_bias is not None:
|
||||
generation_config.sequence_bias = req.logit_bias
|
||||
if req.stop is not None:
|
||||
generation_config.stop_strings = req.stop
|
||||
if req.temperature is not None:
|
||||
generation_config.temperature = req.temperature
|
||||
if req.top_p is not None:
|
||||
generation_config.top_p = req.top_p
|
||||
if req.seed is not None:
|
||||
torch.manual_seed(req.seed)
|
||||
|
||||
return generation_config
|
||||
|
||||
|
||||
class ToolState:
|
||||
"""Lightweight class to keep track of the tool call state."""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the tool call state (assumes we're outside a tool call)."""
|
||||
self.inside_tool_call = False
|
||||
self.has_tool_name_defined = False
|
||||
self.arg_nesting_level = 0
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServeArguments:
|
||||
r"""
|
||||
Arguments for the serve CLI.
|
||||
|
||||
See the metadata arg for each argument's description -- the metadata will be printed with
|
||||
`transformers serve --help`
|
||||
"""
|
||||
|
||||
device: str = field(default="cpu", metadata={"help": "Device to use for inference."})
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, "
|
||||
"the dtype will be automatically derived from the model's weights.",
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
return ServeCommand(nlp, args.host, args.port, args.workers)
|
||||
trust_remote_code: bool = field(
|
||||
default=False, metadata={"help": "Whether to trust remote code when loading a model."}
|
||||
)
|
||||
attn_implementation: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in "
|
||||
"which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
|
||||
},
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 8 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
load_in_4bit: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 4 bit precision for the base model - works only with LoRA."},
|
||||
)
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "Whether to use nested quantization."})
|
||||
|
||||
# Serving settings
|
||||
host: str = field(default="localhost", metadata={"help": "Interface the server will listen to.."})
|
||||
port: int = field(default=8000, metadata={"help": "Port the server will listen to."})
|
||||
|
||||
class ServeModelInfoResult(BaseModel):
|
||||
"""
|
||||
Expose model information
|
||||
"""
|
||||
|
||||
infos: dict
|
||||
|
||||
|
||||
class ServeTokenizeResult(BaseModel):
|
||||
"""
|
||||
Tokenize result model
|
||||
"""
|
||||
|
||||
tokens: list[str]
|
||||
tokens_ids: Optional[list[int]]
|
||||
|
||||
|
||||
class ServeDeTokenizeResult(BaseModel):
|
||||
"""
|
||||
DeTokenize result model
|
||||
"""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
class ServeForwardResult(BaseModel):
|
||||
"""
|
||||
Forward result model
|
||||
"""
|
||||
|
||||
output: Any
|
||||
# Other settings
|
||||
log_level: str = field(
|
||||
default="info", metadata={"help": "Logging level as a string. Example: 'info' or 'warning'."}
|
||||
)
|
||||
|
||||
|
||||
class ServeCommand(BaseTransformersCLICommand):
|
||||
loaded_model: Optional[str] = None
|
||||
model: PreTrainedModel
|
||||
tokenizer: PreTrainedTokenizerFast
|
||||
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
"""
|
||||
@ -98,131 +224,409 @@ class ServeCommand(BaseTransformersCLICommand):
|
||||
Args:
|
||||
parser: Root parser to register command-specific arguments
|
||||
"""
|
||||
serve_parser = parser.add_parser(
|
||||
"serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
|
||||
)
|
||||
serve_parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
choices=get_supported_tasks(),
|
||||
help="The task to run the pipeline on",
|
||||
)
|
||||
serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
|
||||
serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
|
||||
serve_parser.add_argument("--workers", type=int, default=1, help="Number of http workers")
|
||||
serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
|
||||
serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
|
||||
serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
|
||||
serve_parser.add_argument(
|
||||
"--device",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
|
||||
)
|
||||
dataclass_types = (ServeArguments,)
|
||||
serve_parser = parser.add_parser("serve", dataclass_types=dataclass_types)
|
||||
serve_parser.set_defaults(func=serve_command_factory)
|
||||
|
||||
def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
|
||||
self._pipeline = pipeline
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.workers = workers
|
||||
|
||||
if not _serve_dependencies_installed:
|
||||
raise RuntimeError(
|
||||
"Using serve command requires FastAPI and uvicorn. "
|
||||
'Please install transformers with [serving]: pip install "transformers[serving]". '
|
||||
"Or install FastAPI and uvicorn separately."
|
||||
def __init__(self, args: ServeArguments):
|
||||
if not is_pydantic_available() or not is_fastapi_available() or not is_uvicorn_available():
|
||||
raise ImportError(
|
||||
"Missing dependencies for the serving CLI. Please install with `pip install transformers[serving]`"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Serving model over {host}:{port}")
|
||||
self._app = FastAPI(
|
||||
routes=[
|
||||
APIRoute(
|
||||
"/",
|
||||
self.model_info,
|
||||
response_model=ServeModelInfoResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["GET"],
|
||||
|
||||
self.args = args
|
||||
self.use_continuous_batching = self.args.attn_implementation == "sdpa_paged"
|
||||
|
||||
# State: preserves information about the last call and last KV cache, to determine whether we can reuse the KV
|
||||
# cache and avoid re-running prefil
|
||||
self.last_messages = None
|
||||
self.last_kv_cache = None
|
||||
|
||||
transformers_logger = logging.get_logger("transformers")
|
||||
transformers_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
|
||||
cb_logger = logging.get_logger("transformers.generation.continuous_batching")
|
||||
cb_logger.setLevel(logging.log_levels[self.args.log_level.lower()])
|
||||
|
||||
def build_chunk(
|
||||
self,
|
||||
content: str,
|
||||
request_id: str,
|
||||
role: Optional[str] = None,
|
||||
finish_reason: Optional[str] = None,
|
||||
tool_calls: Optional[list[ChatCompletionStreamOutputDeltaToolCall]] = None,
|
||||
) -> str:
|
||||
payload = {
|
||||
"object": "chat.completion.chunk",
|
||||
"id": request_id,
|
||||
"created": int(time.time()),
|
||||
"model": self.loaded_model,
|
||||
"system_fingerprint": "",
|
||||
"choices": [
|
||||
ChatCompletionStreamOutputChoice(
|
||||
delta=ChatCompletionStreamOutputDelta(
|
||||
role=role,
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
APIRoute(
|
||||
"/tokenize",
|
||||
self.tokenize,
|
||||
response_model=ServeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/detokenize",
|
||||
self.detokenize,
|
||||
response_model=ServeDeTokenizeResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
APIRoute(
|
||||
"/forward",
|
||||
self.forward,
|
||||
response_model=ServeForwardResult,
|
||||
response_class=JSONResponse,
|
||||
methods=["POST"],
|
||||
),
|
||||
],
|
||||
timeout=600,
|
||||
)
|
||||
index=0,
|
||||
logprobs=None,
|
||||
finish_reason=finish_reason,
|
||||
),
|
||||
],
|
||||
}
|
||||
return f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
def run(self):
|
||||
run(self._app, host=self.host, port=self.port, workers=self.workers)
|
||||
app = FastAPI()
|
||||
|
||||
def model_info(self):
|
||||
return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
|
||||
if self.use_continuous_batching:
|
||||
self.continuous_batching(app)
|
||||
else:
|
||||
self.generate(app)
|
||||
|
||||
def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_text_gen_models() -> list[ModelInfo]:
|
||||
"""
|
||||
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
|
||||
model working with generate can work.
|
||||
|
||||
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
|
||||
integrations.
|
||||
"""
|
||||
return [
|
||||
model_info("Menlo/Jan-nano"),
|
||||
model_info("Menlo/Jan-nano-128k"),
|
||||
model_info("Qwen/Qwen2.5-0.5B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-3B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-7B-Instruct"),
|
||||
model_info("Qwen/Qwen2.5-14B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.1-8B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.2-1B-Instruct"),
|
||||
model_info("meta-llama/Llama-3.3-70B-Instruct"),
|
||||
]
|
||||
|
||||
@app.get("/v1/models")
|
||||
def get_all_models():
|
||||
return JSONResponse(
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model.id,
|
||||
"object": "model",
|
||||
"crated": model.created_at.timestamp(),
|
||||
"owned_by": model.author,
|
||||
}
|
||||
for model in get_text_gen_models()
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
uvicorn.run(app, host=self.args.host, port=self.args.port, log_level=self.args.log_level)
|
||||
|
||||
def continuous_batching(self, app):
|
||||
generation_config = GenerationConfig(
|
||||
eos_token_id=self.tokenizer.eos_token_id,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
use_cache=False,
|
||||
num_blocks=1,
|
||||
block_size=1024,
|
||||
do_sample=False,
|
||||
max_batch_tokens=10,
|
||||
scheduler="fifo",
|
||||
)
|
||||
|
||||
manager: ContinuousBatchingManager = self.model.init_continuous_batching(
|
||||
generation_config=generation_config, streaming=True
|
||||
)
|
||||
manager.start()
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
def _serve(req: "ChatCompletionInput"):
|
||||
if not req.stream:
|
||||
return {"error": "Only streaming mode is supported."}
|
||||
|
||||
update_model = req.model != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
|
||||
chat = req.messages
|
||||
inputs = self.tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
self.model.device
|
||||
)
|
||||
|
||||
generation_config = create_generation_config_from_req(req)
|
||||
|
||||
def stream_response(_inputs):
|
||||
try:
|
||||
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
|
||||
request_id = manager.add_request(_inputs, request_id=req.request_id, max_new_tokens=max_new_tokens)
|
||||
queue_is_flushed = False
|
||||
|
||||
for result in manager:
|
||||
if req.request_id is not None and not queue_is_flushed:
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
continue
|
||||
else:
|
||||
queue_is_flushed = True
|
||||
|
||||
finish_reason = "stop" if result.status == RequestStatus.FINISHED else None
|
||||
yield self.build_chunk(result.next_token, request_id=request_id, finish_reason=finish_reason)
|
||||
|
||||
if result.status == RequestStatus.FINISHED:
|
||||
break
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
yield f'data: {{"error": "{str(e)}"}}'
|
||||
|
||||
return StreamingResponse(stream_response(inputs[0]), media_type="text/event-stream")
|
||||
|
||||
def is_continuation(self, req: "ChatCompletionInput") -> bool:
|
||||
"""
|
||||
Tokenize the provided input and eventually returns corresponding tokens id: - **text_input**: String to
|
||||
tokenize - **return_ids**: Boolean flags indicating if the tokens have to be converted to their integer
|
||||
mapping.
|
||||
"""
|
||||
try:
|
||||
tokens_txt = self._pipeline.tokenizer.tokenize(text_input)
|
||||
Determines whether the current request is a continuation of the last request. In other words, if it is the
|
||||
same chat session.
|
||||
|
||||
if return_ids:
|
||||
tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
|
||||
return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
|
||||
Args:
|
||||
req (`ChatCompletionInput`): The request to check.
|
||||
|
||||
Returns:
|
||||
`True` if the request is a continuation of the last request, `False` otherwise.
|
||||
"""
|
||||
req_continues_last_messages = True
|
||||
|
||||
# No cached messages: this is a new request
|
||||
if self.last_messages is None:
|
||||
req_continues_last_messages = False
|
||||
# The new request has fewer rounds of conversation: this is a new request
|
||||
elif len(self.last_messages) > len(req.messages):
|
||||
req_continues_last_messages = False
|
||||
# Otherwise, check that the last messages are a subset of the new request
|
||||
else:
|
||||
for i in range(len(self.last_messages)):
|
||||
if self.last_messages[i] != req.messages[i]:
|
||||
req_continues_last_messages = False
|
||||
break
|
||||
|
||||
self.last_messages = req.messages
|
||||
return req_continues_last_messages
|
||||
|
||||
def generate(self, app):
|
||||
@app.post("/v1/chat/completions")
|
||||
def _serve(req: "ChatCompletionInput"):
|
||||
update_model = req.model != self.loaded_model
|
||||
|
||||
if update_model:
|
||||
self.model, self.tokenizer = self.load_model_and_tokenizer(req.model, self.args)
|
||||
|
||||
if not req.stream:
|
||||
return {"error": "Only streaming mode is supported."}
|
||||
|
||||
# HACK for tiny-agents: it sends a request after the assistant message (???). Let's assume we can't have a
|
||||
# request whose last message is from the assistant.
|
||||
if req.messages[-1].role == "assistant":
|
||||
return
|
||||
|
||||
# ====== TOOL PREPROCESSING LOGIC ======
|
||||
tool_model_family = None
|
||||
for supported_model_families in _MODELS_WITH_TOOL_SUPPORT:
|
||||
if supported_model_families in self.model.config.architectures[0].lower():
|
||||
tool_model_family = supported_model_families
|
||||
break
|
||||
# TODO: trigger 2 constrained generations after the tool call start token is emitted:
|
||||
# 1. force generation to pick from the tool names
|
||||
# 2. force generation to pick from that tool's arguments
|
||||
# ====== END OF TOOL PREPROCESSING LOGIC ======
|
||||
|
||||
if tool_model_family is not None:
|
||||
text = self.tokenizer.apply_chat_template(
|
||||
req.messages, add_generation_prompt=True, tokenize=False, tools=req.tools
|
||||
)
|
||||
else:
|
||||
return ServeTokenizeResult(tokens=tokens_txt)
|
||||
text = self.tokenizer.apply_chat_template(req.messages, add_generation_prompt=True, tokenize=False)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)["input_ids"]
|
||||
request_id = req.request_id if req.request_id is not None else "req_0"
|
||||
|
||||
def detokenize(
|
||||
self,
|
||||
tokens_ids: list[int] = Body(None, embed=True),
|
||||
skip_special_tokens: bool = Body(False, embed=True),
|
||||
cleanup_tokenization_spaces: bool = Body(True, embed=True),
|
||||
):
|
||||
"""
|
||||
Detokenize the provided tokens ids to readable text: - **tokens_ids**: List of tokens ids -
|
||||
**skip_special_tokens**: Flag indicating to not try to decode special tokens - **cleanup_tokenization_spaces**:
|
||||
Flag indicating to remove all leading/trailing spaces and intermediate ones.
|
||||
"""
|
||||
try:
|
||||
decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
|
||||
return ServeDeTokenizeResult(model="", text=decoded_str)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
|
||||
generation_streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True, skip_prompt=True)
|
||||
|
||||
async def forward(self, inputs=Body(None, embed=True)):
|
||||
"""
|
||||
**inputs**: **attention_mask**: **tokens_type_ids**:
|
||||
"""
|
||||
generation_config = create_generation_config_from_req(req)
|
||||
max_new_tokens = req.max_tokens or generation_config.max_new_tokens or 256
|
||||
generation_config.max_new_tokens = max_new_tokens
|
||||
|
||||
# Check we don't have empty string
|
||||
if len(inputs) == 0:
|
||||
return ServeForwardResult(output=[], attention=[])
|
||||
last_kv_cache = None
|
||||
if self.is_continuation(req) and not update_model:
|
||||
last_kv_cache = self.last_kv_cache
|
||||
|
||||
try:
|
||||
# Forward through the model
|
||||
output = self._pipeline(inputs)
|
||||
return ServeForwardResult(output=output)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, {"error": str(e)})
|
||||
generation_kwargs = {
|
||||
"inputs": inputs,
|
||||
"attention_mask": torch.ones_like(inputs),
|
||||
"streamer": generation_streamer,
|
||||
"generation_config": generation_config,
|
||||
"return_dict_in_generate": True,
|
||||
"past_key_values": last_kv_cache,
|
||||
}
|
||||
|
||||
def stream_response(streamer, _request_id):
|
||||
# Thin wrapper to save the KV cache after generation
|
||||
def generate_with_cache(**kwargs):
|
||||
generate_output = self.model.generate(**kwargs)
|
||||
self.last_kv_cache = generate_output.past_key_values
|
||||
|
||||
thread = Thread(target=generate_with_cache, kwargs=generation_kwargs)
|
||||
|
||||
try:
|
||||
thread.start()
|
||||
tool_state = ToolState()
|
||||
|
||||
for result in streamer:
|
||||
# ====== TOOL CALL LOGIC ======
|
||||
if tool_model_family is not None:
|
||||
# Start of a tool call: reset state variables, set `inside_tool_call`
|
||||
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["start"]:
|
||||
tool_state.inside_tool_call = True
|
||||
continue
|
||||
|
||||
# End of tool call: reset `inside_tool_call`, emit a `finish_reason`
|
||||
if result.strip() == _TOOL_CALL_TOKENS[tool_model_family]["end"]:
|
||||
tool_state.reset()
|
||||
yield self.build_chunk("", _request_id, role=None, finish_reason="tool_calls")
|
||||
continue
|
||||
|
||||
# Inside a tool call
|
||||
if tool_state.inside_tool_call:
|
||||
tool_state.buffer += result
|
||||
|
||||
# First step: extract the tool name (may need several tokens, and we can't emit a delta
|
||||
# until we have the full name)
|
||||
if not tool_state.has_tool_name_defined:
|
||||
tool_name = re.search(r"\"name\": \"(.*?)\"", tool_state.buffer)
|
||||
if tool_name is None:
|
||||
continue
|
||||
else:
|
||||
tool_name = tool_name.group(1)
|
||||
tool_state.has_tool_name_defined = True
|
||||
tool = ChatCompletionStreamOutputDeltaToolCall(
|
||||
function=ChatCompletionStreamOutputFunction(
|
||||
name=tool_name,
|
||||
arguments=None,
|
||||
),
|
||||
index=0,
|
||||
type="function",
|
||||
id=_request_id + "_tool_call", # Only the first tool call delta has an id
|
||||
)
|
||||
|
||||
# Second step: extract tool arguments. The tool arguments can be seen as a json string
|
||||
# within the tool json string. We emit a delta for the arguments.
|
||||
else:
|
||||
# Empty text: skip
|
||||
if result == "":
|
||||
continue
|
||||
# Until we see the `"arguments": {` in the buffer, we skip
|
||||
# TODO: other models will likely need more elaborate processing here
|
||||
if '"arguments": {' not in tool_state.buffer:
|
||||
continue
|
||||
|
||||
# Handle nesting. We want to exclude the last } from the emitted arguments (it's
|
||||
# closing the outermost nesting level, outside the arguments block)
|
||||
tool_state.arg_nesting_level += result.count("{")
|
||||
tool_state.arg_nesting_level -= result.count("}")
|
||||
if tool_state.arg_nesting_level < 0:
|
||||
result = "".join(result.split("}")[:-2]) + "}" # e.g. "4}}\n" -> "4}"
|
||||
|
||||
tool = ChatCompletionStreamOutputDeltaToolCall(
|
||||
function=ChatCompletionStreamOutputFunction(
|
||||
arguments=result,
|
||||
),
|
||||
index=0,
|
||||
type="function",
|
||||
id=None,
|
||||
)
|
||||
|
||||
yield self.build_chunk(None, _request_id, role=None, tool_calls=[tool])
|
||||
continue
|
||||
# ====== END OF TOOL CALL LOGIC ======
|
||||
|
||||
# All non-tool related tokens are emitted as assistant messages
|
||||
yield self.build_chunk(result, _request_id, role="assistant")
|
||||
yield self.build_chunk(None, _request_id, role=None, finish_reason="stop")
|
||||
|
||||
thread.join()
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise
|
||||
yield f'data: {{"error": "{str(e)}"}}'
|
||||
|
||||
finally:
|
||||
thread.join()
|
||||
|
||||
return StreamingResponse(stream_response(generation_streamer, request_id), media_type="text/event-stream")
|
||||
|
||||
@staticmethod
|
||||
def get_quantization_config(model_args: ServeArguments) -> Optional["BitsAndBytesConfig"]:
|
||||
if model_args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
# For consistency with model weights, we use the same value as `torch_dtype`
|
||||
bnb_4bit_compute_dtype=model_args.torch_dtype,
|
||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||
bnb_4bit_quant_storage=model_args.torch_dtype,
|
||||
)
|
||||
elif model_args.load_in_8bit:
|
||||
quantization_config = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
else:
|
||||
quantization_config = None
|
||||
|
||||
return quantization_config
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
self, model_id_and_revision: str, args: ServeArguments
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizerFast]:
|
||||
logger.warning(f"Loading {model_id_and_revision}")
|
||||
|
||||
if "@" in model_id_and_revision:
|
||||
model_id, revision = model_id_and_revision.split("@", 1)
|
||||
else:
|
||||
model_id, revision = model_id_and_revision, "main"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
)
|
||||
|
||||
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||
quantization_config = self.get_quantization_config(args)
|
||||
|
||||
model_kwargs = {
|
||||
"revision": revision,
|
||||
"attn_implementation": args.attn_implementation,
|
||||
"torch_dtype": torch_dtype,
|
||||
"device_map": "auto",
|
||||
"quantization_config": quantization_config,
|
||||
"trust_remote_code": args.trust_remote_code,
|
||||
}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
|
||||
|
||||
if model.generation_config.max_new_tokens is not None and model.generation_config.max_new_tokens < 256:
|
||||
model.generation_config.max_new_tokens = 256
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
self.loaded_model = model_id_and_revision
|
||||
|
||||
print("Loaded model", model_id_and_revision)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
serve = ServeCommand()
|
||||
serve.run()
|
||||
|
@ -54,7 +54,7 @@ deps = {
|
||||
"protobuf": "protobuf",
|
||||
"psutil": "psutil",
|
||||
"pyyaml": "pyyaml>=5.1",
|
||||
"pydantic": "pydantic",
|
||||
"pydantic": "pydantic>=2",
|
||||
"pytest": "pytest>=7.2.0",
|
||||
"pytest-asyncio": "pytest-asyncio",
|
||||
"pytest-rerunfailures": "pytest-rerunfailures",
|
||||
|
@ -27,6 +27,8 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tokenizers import Tokenizer
|
||||
from tokenizers.decoders import DecodeStream
|
||||
from torch.profiler import profile, schedule, tensorboard_trace_handler
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -72,6 +74,7 @@ class GenerationOutput:
|
||||
error: Optional[str] = None
|
||||
status: RequestStatus = RequestStatus.PENDING
|
||||
created_time: float = field(default_factory=time.time)
|
||||
next_token: Optional[int] = field(default_factory=int)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -96,6 +99,7 @@ class RequestState:
|
||||
eos_token_id: int = -1
|
||||
created_time: float = field(default_factory=time.time)
|
||||
error: Optional[str] = None
|
||||
next_token: Optional[str] = None
|
||||
|
||||
def current_len(self) -> int:
|
||||
"""Get the current length of the sequence (prompt + generated tokens)."""
|
||||
@ -139,6 +143,7 @@ class RequestState:
|
||||
generated_tokens=self.static_outputs,
|
||||
logprobs=[],
|
||||
error=self.error,
|
||||
next_token=self.next_token,
|
||||
)
|
||||
|
||||
|
||||
@ -764,6 +769,9 @@ class ContinuousBatchProcessor:
|
||||
|
||||
self.setup_static_tensors()
|
||||
|
||||
self.tokenizer = Tokenizer.from_pretrained(self.config._name_or_path)
|
||||
self.decode_stream = DecodeStream(skip_special_tokens=True)
|
||||
|
||||
@traced(standalone=True)
|
||||
def setup_static_tensors(self):
|
||||
T = self.max_batch_tokens
|
||||
@ -995,7 +1003,7 @@ class ContinuousBatchProcessor:
|
||||
def _maybe_send_output(self, state: RequestState, token: int):
|
||||
"""Send output to the queue based on streaming mode and request state."""
|
||||
if self.streaming:
|
||||
state.next_token = token
|
||||
state.next_token = self.decode_stream.step(self.tokenizer, state.static_outputs[-1])
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
elif state.status == RequestStatus.FINISHED:
|
||||
self.output_queue.put(state.to_generation_output())
|
||||
@ -1102,6 +1110,7 @@ class ContinuousBatchingManager:
|
||||
self.profile = getattr(generation_config, "profile", False)
|
||||
self.manual_eviction = manual_eviction
|
||||
self.batch_processor: Optional[ContinuousBatchProcessor] = None
|
||||
self.decode_stream = DecodeStream(skip_special_tokens=True)
|
||||
|
||||
@traced
|
||||
def start(self):
|
||||
|
@ -57,10 +57,12 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
|
||||
local_rank = int(os.environ["LOCAL_RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
|
||||
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "ccl", "hpu": "hccl"}
|
||||
backend_map = {"cuda": "nccl", "cpu": "gloo", "xpu": "xccl", "hpu": "hccl"}
|
||||
backend = backend_map.get(device_type)
|
||||
if device_type == "cpu" and int(os.environ.get("CCL_WORKER_COUNT", 0)):
|
||||
backend = "ccl"
|
||||
if device_type == "xpu" and not is_torch_greater_or_equal("2.8", accept_dev=True):
|
||||
backend = "ccl"
|
||||
|
||||
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
current_device = getattr(torch, device_type)
|
||||
|
@ -301,10 +301,10 @@ class Gemma3nTextConfig(PretrainedConfig):
|
||||
|
||||
class Gemma3nAudioConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
|
||||
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
|
||||
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
|
||||
an `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
|
||||
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
|
||||
Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
|
||||
the documentation from [`Gemma3nAudioConfig`] for more information.
|
||||
|
@ -912,7 +912,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Gemma3nAudioEncoder(PreTrainedModel):
|
||||
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
|
||||
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
||||
|
||||
config_class = Gemma3nAudioConfig
|
||||
|
||||
|
@ -313,10 +313,10 @@ class Gemma3nTextConfig(Gemma2Config, PretrainedConfig):
|
||||
|
||||
class Gemma3nAudioConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`], based on Gogole's
|
||||
[Universal Speech Model](). It is used to instantiate an Gemma3nAudioEncoder model according to the specified
|
||||
arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
|
||||
configuration to that of the Gemma 3n E4B, e.g. [google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
This is the configuration class to store the configuration of a [`Gemma3nAudioEncoder`]. It is used to instantiate
|
||||
an `Gemma3nAudioEncoder` model according to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to that of the Gemma 3n E4B, e.g.,
|
||||
[google/gemma-3n-E4B](https://huggingface.co/google/gemma-3n-E4B).
|
||||
|
||||
Configuration objects that inherit from [`Gemma3nAudioConfig`] and can be used to control the model outputs. Read
|
||||
the documentation from [`Gemma3nAudioConfig`] for more information.
|
||||
@ -1473,7 +1473,7 @@ class Gemma3nAudioConformerBlock(nn.Module):
|
||||
|
||||
|
||||
class Gemma3nAudioEncoder(PreTrainedModel):
|
||||
"""A Universal Speech Encoder -- https://arxiv.org/abs/2303.01037"""
|
||||
"""An audio encoder based on the [Universal Speech Model](https://arxiv.org/abs/2303.01037) architecture."""
|
||||
|
||||
config_class = Gemma3nAudioConfig
|
||||
|
||||
|
@ -121,6 +121,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
do_convert_rgb: bool,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]],
|
||||
device: Optional[Union[str, torch.device]],
|
||||
disable_grouping: Optional[bool],
|
||||
):
|
||||
"""
|
||||
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
|
||||
@ -173,7 +174,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
resized_height, resized_width = height, width
|
||||
|
||||
# Group images by size for batched resizing
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images)
|
||||
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
|
||||
resized_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
if do_resize:
|
||||
@ -191,7 +192,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
resized_images = reorder_images(resized_images_grouped, grouped_images_index)
|
||||
# Group images by size for further processing
|
||||
# Needed in case do_resize is False, or resize returns images with different sizes
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images)
|
||||
grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
|
||||
processed_images_grouped = {}
|
||||
for shape, stacked_images in grouped_images.items():
|
||||
# Fused rescale and normalize
|
||||
@ -249,6 +250,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
|
||||
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
disable_grouping: Optional[bool] = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@ -323,6 +325,7 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
do_convert_rgb=do_convert_rgb,
|
||||
input_data_format=input_data_format,
|
||||
device=device,
|
||||
disable_grouping=disable_grouping,
|
||||
)
|
||||
pixel_values.extend(patches)
|
||||
vision_grid_thws.append(image_grid_thw)
|
||||
@ -351,11 +354,11 @@ class Glm4vImageProcessorFast(BaseImageProcessorFast):
|
||||
|
||||
factor = patch_size * merge_size
|
||||
resized_height, resized_width = smart_resize(
|
||||
t=self.temporal_patch_size,
|
||||
num_frames=self.temporal_patch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
temporal_factor=self.temporal_patch_size,
|
||||
factor=factor,
|
||||
t_factor=self.temporal_patch_size,
|
||||
)
|
||||
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
|
||||
return grid_h * grid_w
|
||||
|
@ -287,6 +287,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -324,7 +325,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=False,
|
||||
is_causal=self.is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
@ -1013,7 +1014,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
@ -1043,7 +1044,6 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
|
||||
llm_pos_ids_list = []
|
||||
video_frame_num = 1
|
||||
image_index, video_index = 0, 0
|
||||
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
@ -1085,9 +1085,7 @@ class Glm4vModel(Glm4vPreTrainedModel):
|
||||
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
|
||||
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||
|
||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
|
||||
video_index += 1
|
||||
|
@ -516,6 +516,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
|
||||
self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -553,7 +554,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scale,
|
||||
is_causal=False,
|
||||
is_causal=self.is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.squeeze(0)
|
||||
@ -1112,7 +1113,7 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
dtype=input_ids.dtype,
|
||||
device=input_ids.device,
|
||||
)
|
||||
|
||||
image_index, video_index = 0, 0
|
||||
attention_mask = attention_mask.to(total_input_ids.device)
|
||||
for i, input_ids in enumerate(total_input_ids):
|
||||
input_ids = input_ids[attention_mask[i] == 1]
|
||||
@ -1142,7 +1143,6 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
|
||||
llm_pos_ids_list = []
|
||||
video_frame_num = 1
|
||||
image_index, video_index = 0, 0
|
||||
|
||||
for modality_type, start_idx, end_idx in input_type_group:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
@ -1184,9 +1184,7 @@ class Glm4vModel(Qwen2_5_VLModel):
|
||||
t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
||||
|
||||
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
|
||||
|
||||
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
|
||||
|
||||
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
|
||||
|
||||
video_index += 1
|
||||
|
@ -336,6 +336,11 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
if num_input_ids is not None:
|
||||
weights = weights[:, :, num_input_ids:, :]
|
||||
|
||||
# Since we ignore `decoder_input_ids` in the DTW and in the case where we generated only one token (for which we don't have cross attentions, see below comments),
|
||||
# the DTW sequence length is 0 and we should return only 0.0s for the token timestamps
|
||||
if weights.shape[2] == 0:
|
||||
return timestamps
|
||||
|
||||
if num_frames is None or isinstance(num_frames, int):
|
||||
# Normalize and smoothen the weights.
|
||||
std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
@ -366,9 +371,12 @@ class WhisperGenerationMixin(GenerationMixin):
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] * time_precision
|
||||
|
||||
# each predicted token has a corresponding timestamp, expect the eos token for which we don't retrieve cross attentions
|
||||
# each predicted token has a corresponding timestamp, expect the eos token (or last predicted token) for which we don't retrieve cross attentions
|
||||
# (indeed contrary to OAI that re-run a full foward to retreive cross attentions for each token and therefore also the last one predicted, we retreive
|
||||
# cross attentions directly from the auto-regressive generation, so we don't have cross attentiosn for the token at the end of the sequence. Nevertheless,
|
||||
# that is not important since we expect this last token to be the eos token)
|
||||
# 1. for decoder_input_ids, we set the timestamps to 0.0
|
||||
# 2. for the eos token, we simply duplicate the timestamp of the last non-eos token
|
||||
# 2. for the eos token (or last predicted token), we simply duplicate the timestamp of the last non-eos token
|
||||
timestamps[batch_idx] = torch.cat(
|
||||
[torch.zeros(num_input_ids), torch.tensor(jump_times), torch.tensor([jump_times[-1]])]
|
||||
)
|
||||
|
@ -292,6 +292,30 @@ except importlib.metadata.PackageNotFoundError:
|
||||
_essentia_version = False
|
||||
|
||||
|
||||
_pydantic_available = importlib.util.find_spec("pydantic") is not None
|
||||
try:
|
||||
_pydantic_version = importlib.metadata.version("pydantic")
|
||||
logger.debug(f"Successfully imported pydantic version {_pydantic_version}")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_pydantic_available = False
|
||||
|
||||
|
||||
_fastapi_available = importlib.util.find_spec("fastapi") is not None
|
||||
try:
|
||||
_fastapi_version = importlib.metadata.version("fastapi")
|
||||
logger.debug(f"Successfully imported pydantic version {_fastapi_version}")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_fastapi_available = False
|
||||
|
||||
|
||||
_uvicorn_available = importlib.util.find_spec("uvicorn") is not None
|
||||
try:
|
||||
_uvicorn_version = importlib.metadata.version("uvicorn")
|
||||
logger.debug(f"Successfully imported pydantic version {_uvicorn_version}")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
_uvicorn_available = False
|
||||
|
||||
|
||||
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
|
||||
try:
|
||||
_pretty_midi_version = importlib.metadata.version("pretty_midi")
|
||||
@ -473,6 +497,18 @@ def is_essentia_available():
|
||||
return _essentia_available
|
||||
|
||||
|
||||
def is_pydantic_available():
|
||||
return _pydantic_available
|
||||
|
||||
|
||||
def is_fastapi_available():
|
||||
return _fastapi_available
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
return _uvicorn_available
|
||||
|
||||
|
||||
def is_pretty_midi_available():
|
||||
return _pretty_midi_available
|
||||
|
||||
@ -1843,6 +1879,23 @@ VISION_IMPORT_ERROR = """
|
||||
`pip install pillow`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYDANTIC_IMPORT_ERROR = """
|
||||
{0} requires the pydantic library but it was not found in your environment. You can install it with pip:
|
||||
`pip install pydantic`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
FASTAPI_IMPORT_ERROR = """
|
||||
{0} requires the fastapi library but it was not found in your environment. You can install it with pip:
|
||||
`pip install fastapi`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
UVICORN_IMPORT_ERROR = """
|
||||
{0} requires the uvicorn library but it was not found in your environment. You can install it with pip:
|
||||
`pip install uvicorn`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
PYTESSERACT_IMPORT_ERROR = """
|
||||
@ -1966,6 +2019,9 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("yt_dlp", (is_yt_dlp_available, YT_DLP_IMPORT_ERROR)),
|
||||
("rich", (is_rich_available, RICH_IMPORT_ERROR)),
|
||||
("keras_nlp", (is_keras_nlp_available, KERAS_NLP_IMPORT_ERROR)),
|
||||
("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
|
||||
("fastapi", (is_fastapi_available, FASTAPI_IMPORT_ERROR)),
|
||||
("uvicorn", (is_uvicorn_available, UVICORN_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
78
tests/commands/test_chat.py
Normal file
78
tests/commands/test_chat.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers.commands.chat import ChatArguments, ChatCommand
|
||||
from transformers.testing_utils import CaptureStd
|
||||
|
||||
|
||||
class ChatCLITest(unittest.TestCase):
|
||||
def test_help(self):
|
||||
with patch("sys.argv", ["transformers", "chat", "--help"]), CaptureStd() as cs:
|
||||
with self.assertRaises(SystemExit):
|
||||
cli.main()
|
||||
self.assertIn("chat interface", cs.out.lower())
|
||||
|
||||
@patch.object(ChatCommand, "run")
|
||||
def test_cli_dispatch(self, run_mock):
|
||||
args = ["transformers", "chat", "hf-internal-testing/tiny-random-gpt2"]
|
||||
with patch("sys.argv", args):
|
||||
cli.main()
|
||||
run_mock.assert_called_once()
|
||||
|
||||
def test_parsed_args(self):
|
||||
with (
|
||||
patch.object(ChatCommand, "__init__", return_value=None) as init_mock,
|
||||
patch.object(ChatCommand, "run") as run_mock,
|
||||
patch(
|
||||
"sys.argv",
|
||||
[
|
||||
"transformers",
|
||||
"chat",
|
||||
"test-model",
|
||||
"max_new_tokens=64",
|
||||
],
|
||||
),
|
||||
):
|
||||
cli.main()
|
||||
init_mock.assert_called_once()
|
||||
run_mock.assert_called_once()
|
||||
parsed_args = init_mock.call_args[0][0]
|
||||
self.assertEqual(parsed_args.model_name_or_path_or_address, "test-model")
|
||||
self.assertEqual(parsed_args.generate_flags, ["max_new_tokens=64"])
|
||||
|
||||
|
||||
class ChatUtilitiesTest(unittest.TestCase):
|
||||
def test_save_and_clear_chat(self):
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
|
||||
args = ChatArguments(save_folder=str(tmp_path))
|
||||
args.model_name_or_path_or_address = "test-model"
|
||||
|
||||
chat_history = [{"role": "user", "content": "hi"}]
|
||||
filename = ChatCommand.save_chat(chat_history, args)
|
||||
self.assertTrue(os.path.isfile(filename))
|
||||
|
||||
cleared = ChatCommand.clear_chat_history()
|
||||
self.assertEqual(cleared, [])
|
||||
|
||||
def test_parse_generate_flags(self):
|
||||
dummy = ChatCommand.__new__(ChatCommand)
|
||||
parsed = ChatCommand.parse_generate_flags(dummy, ["temperature=0.5", "max_new_tokens=10"])
|
||||
self.assertEqual(parsed["temperature"], 0.5)
|
||||
self.assertEqual(parsed["max_new_tokens"], 10)
|
47
tests/commands/test_serving.py
Normal file
47
tests/commands/test_serving.py
Normal file
@ -0,0 +1,47 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import transformers.commands.transformers_cli as cli
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.testing_utils import CaptureStd
|
||||
|
||||
|
||||
class ServeCLITest(unittest.TestCase):
|
||||
def test_help(self):
|
||||
with patch("sys.argv", ["transformers", "serve", "--help"]), CaptureStd() as cs:
|
||||
with self.assertRaises(SystemExit):
|
||||
cli.main()
|
||||
self.assertIn("serve", cs.out.lower())
|
||||
|
||||
def test_parsed_args(self):
|
||||
with (
|
||||
patch.object(ServeCommand, "__init__", return_value=None) as init_mock,
|
||||
patch.object(ServeCommand, "run") as run_mock,
|
||||
patch("sys.argv", ["transformers", "serve", "--host", "0.0.0.0", "--port", "9000"]),
|
||||
):
|
||||
cli.main()
|
||||
init_mock.assert_called_once()
|
||||
run_mock.assert_called_once()
|
||||
parsed_args = init_mock.call_args[0][0]
|
||||
self.assertEqual(parsed_args.host, "0.0.0.0")
|
||||
self.assertEqual(parsed_args.port, 9000)
|
||||
|
||||
def test_build_chunk(self):
|
||||
dummy = ServeCommand.__new__(ServeCommand)
|
||||
dummy.args = type("Args", (), {})()
|
||||
chunk = ServeCommand.build_chunk(dummy, "hello", "req0", finish_reason="stop")
|
||||
self.assertIn("chat.completion.chunk", chunk)
|
||||
self.assertIn("data:", chunk)
|
@ -24,6 +24,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_
|
||||
from transformers.generation.configuration_utils import GenerationConfig
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
is_flash_attn_2_available,
|
||||
require_flash_attn,
|
||||
require_read_token,
|
||||
@ -136,6 +137,9 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
||||
class Cohere2IntegrationTest(unittest.TestCase):
|
||||
input_text = ["Hello I am doing", "Hi today"]
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_model_bf16(self):
|
||||
model_id = "CohereForAI/c4ai-command-r7b-12-2024"
|
||||
EXPECTED_TEXTS = [
|
||||
|
@ -29,6 +29,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
is_flaky,
|
||||
require_timm,
|
||||
require_torch,
|
||||
@ -804,34 +805,62 @@ class GroundingDinoModelIntegrationTests(unittest.TestCase):
|
||||
with torch.no_grad():
|
||||
outputs = model(**text_inputs, **image_inputs)
|
||||
|
||||
# Loss differs by CPU and GPU, also this can be changed in future.
|
||||
expected_loss_dict = {
|
||||
"loss_ce": torch.tensor(1.1147),
|
||||
"loss_bbox": torch.tensor(0.2031),
|
||||
"loss_giou": torch.tensor(0.5819),
|
||||
"loss_ce_0": torch.tensor(1.1941),
|
||||
"loss_bbox_0": torch.tensor(0.1978),
|
||||
"loss_giou_0": torch.tensor(0.5524),
|
||||
"loss_ce_1": torch.tensor(1.1621),
|
||||
"loss_bbox_1": torch.tensor(0.1909),
|
||||
"loss_giou_1": torch.tensor(0.5892),
|
||||
"loss_ce_2": torch.tensor(1.1641),
|
||||
"loss_bbox_2": torch.tensor(0.1892),
|
||||
"loss_giou_2": torch.tensor(0.5626),
|
||||
"loss_ce_3": torch.tensor(1.1943),
|
||||
"loss_bbox_3": torch.tensor(0.1941),
|
||||
"loss_giou_3": torch.tensor(0.5607),
|
||||
"loss_ce_4": torch.tensor(1.0956),
|
||||
"loss_bbox_4": torch.tensor(0.2008),
|
||||
"loss_giou_4": torch.tensor(0.5836),
|
||||
"loss_ce_enc": torch.tensor(16226.3164),
|
||||
"loss_bbox_enc": torch.tensor(0.3063),
|
||||
"loss_giou_enc": torch.tensor(0.7380),
|
||||
}
|
||||
# Loss differs by CPU and accelerator, also this can be changed in future.
|
||||
expected_loss_dicts = Expectations(
|
||||
{
|
||||
("xpu", 3): {
|
||||
"loss_ce": torch.tensor(1.1147),
|
||||
"loss_bbox": torch.tensor(0.2031),
|
||||
"loss_giou": torch.tensor(0.5819),
|
||||
"loss_ce_0": torch.tensor(1.1941),
|
||||
"loss_bbox_0": torch.tensor(0.1978),
|
||||
"loss_giou_0": torch.tensor(0.5524),
|
||||
"loss_ce_1": torch.tensor(1.1621),
|
||||
"loss_bbox_1": torch.tensor(0.1909),
|
||||
"loss_giou_1": torch.tensor(0.5892),
|
||||
"loss_ce_2": torch.tensor(1.1641),
|
||||
"loss_bbox_2": torch.tensor(0.1892),
|
||||
"loss_giou_2": torch.tensor(0.5626),
|
||||
"loss_ce_3": torch.tensor(1.1943),
|
||||
"loss_bbox_3": torch.tensor(0.1941),
|
||||
"loss_giou_3": torch.tensor(0.5592),
|
||||
"loss_ce_4": torch.tensor(1.0956),
|
||||
"loss_bbox_4": torch.tensor(0.2037),
|
||||
"loss_giou_4": torch.tensor(0.5813),
|
||||
"loss_ce_enc": torch.tensor(16226.3164),
|
||||
"loss_bbox_enc": torch.tensor(0.3063),
|
||||
"loss_giou_enc": torch.tensor(0.7380),
|
||||
},
|
||||
("cuda", None): {
|
||||
"loss_ce": torch.tensor(1.1147),
|
||||
"loss_bbox": torch.tensor(0.2031),
|
||||
"loss_giou": torch.tensor(0.5819),
|
||||
"loss_ce_0": torch.tensor(1.1941),
|
||||
"loss_bbox_0": torch.tensor(0.1978),
|
||||
"loss_giou_0": torch.tensor(0.5524),
|
||||
"loss_ce_1": torch.tensor(1.1621),
|
||||
"loss_bbox_1": torch.tensor(0.1909),
|
||||
"loss_giou_1": torch.tensor(0.5892),
|
||||
"loss_ce_2": torch.tensor(1.1641),
|
||||
"loss_bbox_2": torch.tensor(0.1892),
|
||||
"loss_giou_2": torch.tensor(0.5626),
|
||||
"loss_ce_3": torch.tensor(1.1943),
|
||||
"loss_bbox_3": torch.tensor(0.1941),
|
||||
"loss_giou_3": torch.tensor(0.5607),
|
||||
"loss_ce_4": torch.tensor(1.0956),
|
||||
"loss_bbox_4": torch.tensor(0.2008),
|
||||
"loss_giou_4": torch.tensor(0.5836),
|
||||
"loss_ce_enc": torch.tensor(16226.3164),
|
||||
"loss_bbox_enc": torch.tensor(0.3063),
|
||||
"loss_giou_enc": torch.tensor(0.7380),
|
||||
},
|
||||
}
|
||||
) # fmt: skip
|
||||
expected_loss_dict = expected_loss_dicts.get_expectation()
|
||||
|
||||
expected_loss = torch.tensor(32482.2305)
|
||||
|
||||
for key in expected_loss_dict:
|
||||
self.assertTrue(torch.allclose(outputs.loss_dict[key], expected_loss_dict[key], atol=1e-3))
|
||||
torch.testing.assert_close(outputs.loss_dict[key], expected_loss_dict[key], rtol=1e-5, atol=1e-3)
|
||||
|
||||
self.assertTrue(torch.allclose(outputs.loss, expected_loss, atol=1e-3))
|
||||
|
@ -30,6 +30,8 @@ from transformers import (
|
||||
InstructBlipVisionConfig,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
require_accelerate,
|
||||
require_bitsandbytes,
|
||||
require_torch,
|
||||
@ -722,6 +724,9 @@ def prepare_img():
|
||||
@require_torch
|
||||
@slow
|
||||
class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=False)
|
||||
|
||||
@require_bitsandbytes
|
||||
@require_accelerate
|
||||
def test_inference_vicuna_7b(self):
|
||||
@ -739,13 +744,24 @@ class InstructBlipModelIntegrationTest(unittest.TestCase):
|
||||
outputs = model.generate(**inputs, max_new_tokens=30)
|
||||
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
|
||||
|
||||
expected_outputs = [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 373, 263, 19587, 4272, 11952, 29889] # fmt: off
|
||||
expected_outputs = Expectations(
|
||||
{
|
||||
("xpu", 3): [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 1623, 263, 19587, 4272, 11952, 29889],
|
||||
("cuda", None): [32001] * 32 + [2, 1724, 338, 22910, 1048, 445, 1967, 29973, 450, 22910, 9565, 310, 445, 1967, 338, 393, 263, 767, 338, 13977, 292, 22095, 373, 278, 1250, 310, 263, 13328, 20134, 29963, 1550, 19500, 373, 263, 19587, 4272, 11952, 29889],
|
||||
}
|
||||
) # fmt: off
|
||||
expected_output = expected_outputs.get_expectation()
|
||||
|
||||
self.assertEqual(outputs[0].tolist(), expected_outputs)
|
||||
self.assertEqual(
|
||||
generated_text,
|
||||
"What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving on a busy city street.",
|
||||
)
|
||||
expected_texts = Expectations(
|
||||
{
|
||||
("xpu", 3): "What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving down a busy city street.",
|
||||
("cuda", None): "What is unusual about this image? The unusual aspect of this image is that a man is ironing clothes on the back of a yellow SUV while driving on a busy city street.",
|
||||
}
|
||||
) # fmt: off
|
||||
expected_text = expected_texts.get_expectation()
|
||||
|
||||
self.assertEqual(outputs[0].tolist(), expected_output)
|
||||
self.assertEqual(generated_text, expected_text)
|
||||
|
||||
def test_inference_flant5_xl(self):
|
||||
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
|
||||
|
@ -430,7 +430,7 @@ class InternVLQwen2IntegrationTest(unittest.TestCase):
|
||||
|
||||
expected_outputs = Expectations(
|
||||
{
|
||||
("xpu", 3): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate"',
|
||||
("xpu", 3): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of',
|
||||
("cuda", 7): 'user\n\nDescribe this image\nassistant\nThe image shows a street scene with a traditional Chinese archway, known as a "Chinese Gate" or "Chinese Gate of',
|
||||
}
|
||||
) # fmt: skip
|
||||
@ -793,7 +793,7 @@ class InternVLLlamaIntegrationTest(unittest.TestCase):
|
||||
decoded_output = processor.decode(output[0], skip_special_tokens=True)
|
||||
expected_outputs = Expectations(
|
||||
{
|
||||
("xpu", 3): "user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden path leads to calm lake,\nNature's peaceful grace.",
|
||||
("xpu", 3): "user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.",
|
||||
("cuda", 7): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
|
||||
("cuda", 8): 'user\n\nWrite a haiku for this image\nassistant\nMajestic snow-capped peaks,\nWooden dock stretches to the sea,\nSilent water mirrors.',
|
||||
}
|
||||
|
@ -17,6 +17,8 @@ import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
cleanup,
|
||||
require_read_token,
|
||||
require_torch_large_accelerator,
|
||||
slow,
|
||||
@ -78,10 +80,17 @@ class Llama4IntegrationTest(unittest.TestCase):
|
||||
},
|
||||
]
|
||||
|
||||
def tearDown(self):
|
||||
cleanup(torch_device, gc_collect=True)
|
||||
|
||||
def test_model_17b_16e_fp16(self):
|
||||
EXPECTED_TEXT = [
|
||||
'system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'
|
||||
] # fmt: skip
|
||||
EXPECTED_TEXTS = Expectations(
|
||||
{
|
||||
("xpu", 3): ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach with a blue sky and a body of water in the background. The cow is brown with a white face'],
|
||||
("cuda", None): ['system\n\nYou are a helpful assistant.user\n\nWhat is shown in this image?assistant\n\nThe image shows a cow standing on a beach, with a blue sky and a body of water in the background. The cow is brown with a white'],
|
||||
}
|
||||
) # fmt: skip
|
||||
EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()
|
||||
|
||||
inputs = self.processor.apply_chat_template(
|
||||
self.messages_1, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True
|
||||
|
@ -1797,7 +1797,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
])
|
||||
# fmt: on
|
||||
|
||||
torch.testing.assert_close(generated_ids[0], EXPECTED_OUTPUT)
|
||||
torch.testing.assert_close(generated_ids, EXPECTED_OUTPUT)
|
||||
|
||||
EXPECTED_TRANSCRIPT = [
|
||||
{
|
||||
@ -1971,16 +1971,11 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
},
|
||||
{
|
||||
"text": " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and",
|
||||
# "timestamp": (39.80, 45.36),
|
||||
# above is the expected output on A100.
|
||||
# on CI T4s, due to sligth difference in floating points operations, expected is below
|
||||
"timestamp": (39.80, 45.38),
|
||||
"timestamp": (39.80, 45.36),
|
||||
},
|
||||
{
|
||||
"text": " can discover in it but little of rocky Ithaca.",
|
||||
# "timestamp": (45.36, 49.0),
|
||||
# see above
|
||||
"timestamp": (45.38, 49.0),
|
||||
"timestamp": (45.36, 49.0),
|
||||
},
|
||||
{
|
||||
"text": " Lenell's pictures are a sort of up-guards-and-atom paintings, and Mason's exquisite ittles",
|
||||
@ -2220,7 +2215,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
torch.tensor([44.7000, 44.8600, 44.9400, 45.1400, 45.1400, 45.2800, 45.6200, 45.9000, 46.2600, 47.1600, 47.4800, 47.7400, 48.1000, 48.2800, 48.4000, 48.6200, 48.8400, 49.0400, 49.2800, 49.4800, 49.6600, 49.9400, 50.5400, 50.5400]),
|
||||
torch.tensor([50.5400, 50.6600, 50.8800, 51.2400, 51.7200, 52.8400, 52.9600]),
|
||||
torch.tensor([52.9600, 53.0400, 53.2600, 53.4200, 53.5800, 53.9200, 54.1200, 54.7200, 54.9400, 55.2600, 55.6200, 55.9800, 56.5600, 56.8000, 56.9200, 57.3600, 57.9200, 58.1600, 58.5200, 58.6400, 58.8200, 59.4200, 59.4200]),
|
||||
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.4200, 62.4200])
|
||||
torch.tensor([58.6800, 59.1400, 59.5400, 59.9200, 60.1400, 60.3800, 60.8400, 61.6000, 62.2400, 62.3800, 62.4400])
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
@ -2894,10 +2889,10 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
" Folks, if you watch the show, you know I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories, developing the central headline pawns, definitely maneuvering an oh-so-topical night to F6, faming of classic Sicilian, named or variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a Fisher shows in lip-nitsky attack that culminates in the elegant lethal slow-played all-pass on checkmate that is my nightly monologue, but sometimes sometimes, sometimes folks I sometimes I start a little wake-up side down in the monkey bars of a condemned playground on a super fun site, get all hept up on goofballs, rummage that would discard a tag bag of defective toys, yank out a fistball of disembodied doll limbs, toss them on a stain kid's place mad from a defunct denies, set up a table inside a rusty cargo container down by the warf and challenge toothless drifters to the godless bughouse blitz of tournament that is my segment.",
|
||||
" Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing on those topical anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush. To create the luxury sedan that is my nightly monologue, but sometimes I just sometimes folks, I lurched to consciousness in the back of an abandoned school bus and slapped myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen-moon render a gas tank out of an empty big gulp, filled with white claw and de-natured alcohol, then light a match, letter-rip, and the dis-mented one-man soapbox derby of news that is my segment. Meanwhile.",
|
||||
" Ladies and gentlemen, you know, I spent a lot of time right over there, raising the finest hosting news cattle firmly, yet tenderly milking the latest headlines from their jokes, swollen teats, churning the daily stories into the decadent Provincil style triple cream-breed. It is my nightly monologue, but sometimes sometimes I stagger home hungry after being released by the police and root around in the neighbor's trash can for an old milk carton scraped out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-dawn street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire than a hunker down in hallucinate while eating the Listeria latent demon custard of news that is my segment.",
|
||||
" Folks, you watched this show. You know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Icol Greg Waferandi, who carefully die them in a pallet of bright, zesty shades, and adorn them in the finest, most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, finally attach a mallet hammered strap, pearl hardware, and close-shet to create for you the one-of-a-kind, hout-cout-tour, earned me his burkin bag that is my monologue, but sometimes, sometimes, folks. Sometimes, sometimes, sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Coney Island, where I'm hiding from the triads, I huff some engine lubricants out of a safe way bag, and staggered down the shore to tear the sail off a beach skoener, then I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel, lovely folks. And use it to stitch the sail into a loose pouch-like rock sack, and I stow in the back of a garbage truck to the junkyard, where I pick through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out bindle of news that is my segment. Meanwhile!",
|
||||
" Folks, you watched this show. You know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Icol Greg Waferandi, who carefully die them in a pallet of bright, zesty shades, and adorn them in the finest, most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, finally attach a mallet hammered strap, pearl hardware, and close-shet to create for you the one-of-a-kind, hout-cout-tour, earned me his burkin bag that is my monologue, but sometimes, sometimes, folks. Sometimes, sometimes, sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Coney Island, where I'm hiding from the triads, I huff some engine lubricants out of a safe way bag, and staggered down the shore to tear the sail off a beach skoener, then I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel, lovely folks. And use it to stitch the sail into a loose pouch-like rock sack, and I stow in the back of a garbage truck to the junkyard, where I pick through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out bindle of news that is my segment. Meanwhile.",
|
||||
" You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hardcelser and a sack of pills I stole from a parked ambulance, then arm wrestle a raccoon in the back alley vision quest of news that is my segment.",
|
||||
" You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, besieged, melee, container ship that picked me up floating on the detached door of a port of potty in the Indian Ocean. Then, after a sunstroke induced realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe and a pool chain that accepting my new role as captain and declaring myself King of the Windark Seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create these shopping wet pirate crown of news that is my segment. Meanwhile!",
|
||||
" Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks, I wake up in the baggage hole of Greyhound bus. It's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants. As ovenmets to extract and serve the demented transience pound cake of news that is my segment.",
|
||||
' You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, besieged, melee, container ship that picked me up floating on the detached door of a port of potty in the Indian Ocean. Then, after a sunstroke induced realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe and a pool chain that accepting my new role as captain and declaring myself King of the Windark Seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create these shopping wet pirate crown of news that is my segment. Meanwhile, young man.',
|
||||
" Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks, I wake up in the baggage hole of Greyhound bus. It's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants. As ovenmets to extract and serve the Demented Transience pound cake of news that is my segment.",
|
||||
" Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Slering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seizes to life before me and the hideous collection of loose animal parts and corrupted men tissue that is my segment. Meanwhile.",
|
||||
]
|
||||
# fmt: on
|
||||
@ -2935,6 +2930,7 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
||||
"renormalize_logits": True, # necessary to match OAI beam search implementation
|
||||
}
|
||||
|
||||
set_seed(0)
|
||||
result = model.generate(**inputs, **gen_kwargs)
|
||||
decoded_all = processor.batch_decode(result, skip_special_tokens=True)
|
||||
|
||||
|
@ -22,6 +22,7 @@ from parameterized import parameterized
|
||||
|
||||
from transformers import AutoTokenizer, Zamba2Config, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
Expectations,
|
||||
require_bitsandbytes,
|
||||
require_flash_attn,
|
||||
require_torch,
|
||||
@ -678,14 +679,23 @@ class Zamba2ModelIntegrationTest(unittest.TestCase):
|
||||
]
|
||||
, dtype=torch.float32) # fmt: skip
|
||||
|
||||
EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
|
||||
[
|
||||
0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
|
||||
-6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
|
||||
-6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
|
||||
-6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
|
||||
-6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105 ]
|
||||
, dtype=torch.float32) # fmt: skip
|
||||
EXPECTED_LOGITS_NO_GRAD_1S = Expectations(
|
||||
{
|
||||
("xpu", 3): torch.tensor([0.2027, 6.3481, 3.8392, -5.7279, -6.5090, -6.5088, -6.5087, -6.5088,
|
||||
-6.5087, -6.5088, -6.5090, -6.5089, 7.8796, 13.5483, -6.5088, -6.5080,
|
||||
-6.5090, -6.5086, -6.5090, -6.5090, -6.5089, -6.5090, -6.5088, -6.5090,
|
||||
-6.5089, -6.5090, -6.5090, -6.5097, -6.5086, -6.5089, -6.5092, -6.5089,
|
||||
-6.5088, -6.5090, -6.5090, -6.5088, -6.5090, -6.5091, -6.5087, -6.5089],
|
||||
dtype=torch.float32),
|
||||
("cuda", None): torch.tensor([0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
|
||||
-6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
|
||||
-6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
|
||||
-6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
|
||||
-6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105],
|
||||
dtype=torch.float32),
|
||||
}
|
||||
) # fmt: skip
|
||||
EXPECTED_LOGITS_NO_GRAD_1 = EXPECTED_LOGITS_NO_GRAD_1S.get_expectation()
|
||||
|
||||
torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
|
@ -520,14 +520,14 @@ class Pipeline4BitTest(Base4bitTest):
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
@apply_skip_if_not_implemented
|
||||
class Bnb4bitTestMultiGpu(Base4bitTest):
|
||||
class Bnb4bitTestMultiAccelerator(Base4bitTest):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
def test_multi_gpu_loading(self):
|
||||
def test_multi_accelerator_loading(self):
|
||||
r"""
|
||||
This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
|
||||
Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB of total, 3GB should suffice
|
||||
This tests that the model has been loaded and can be used correctly on a multi-accelerator setup.
|
||||
Let's just try to load a model on 2 accelerators and see if it works. The model we test has ~2GB of total, 3GB should suffice
|
||||
"""
|
||||
device_map = {
|
||||
"transformer.word_embeddings": 0,
|
||||
|
@ -24,7 +24,7 @@ from transformers.testing_utils import (
|
||||
backend_device_count,
|
||||
get_torch_dist_unique_port,
|
||||
require_huggingface_hub_greater_or_equal,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@ -168,6 +168,6 @@ class TestTensorParallel(TestCasePlus):
|
||||
del non_tp_tensor, tp_tensor
|
||||
|
||||
|
||||
@require_torch_multi_gpu
|
||||
class TestTensorParallelCuda(TestTensorParallel):
|
||||
@require_torch_multi_accelerator
|
||||
class TestTensorParallelAccelerator(TestTensorParallel):
|
||||
nproc_per_node = backend_device_count(torch_device)
|
||||
|
Loading…
Reference in New Issue
Block a user