# GLM-4.1V
The examples below show how to generate text from an image, first with [`Pipeline`] and then with the [`Glm4vForConditionalGeneration`] model class and [`AutoProcessor`].
```py
import torch
from transformers import pipeline

pipe = pipeline(
    task="image-text-to-text",
    model="THUDM/GLM-4.1V-9B-Thinking",
    device=0,
    torch_dtype=torch.bfloat16,
)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
pipe(text=messages, max_new_tokens=20, return_full_text=False)
```
```py
import torch
from transformers import Glm4vForConditionalGeneration, AutoProcessor

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)
processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
            },
            {
                "type": "text",
                "text": "Describe this image.",
            },
        ],
    }
]
# Tokenize the chat and move the inputs to the device the model was placed on.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=128)
# Keep only the newly generated tokens by slicing off the prompt.
generated_ids_trimmed = [
    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
```
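The slicing step in the example above is needed because `generate` returns each prompt followed by the new tokens. A minimal sketch of that logic, using plain Python lists in place of token ID tensors; the IDs and the `trim_prompt` helper are made up for illustration and are not part of Transformers:

```python
# Hypothetical token IDs; real values come from the processor and model.generate.
prompt_ids = [[101, 7, 8], [101, 9, 10]]
generated_ids = [[101, 7, 8, 42, 43], [101, 9, 10, 44, 45, 46]]


def trim_prompt(prompt_ids, generated_ids):
    """Drop the leading prompt tokens from each generated sequence."""
    return [out[len(inp):] for inp, out in zip(prompt_ids, generated_ids)]


new_tokens = trim_prompt(prompt_ids, generated_ids)
print(new_tokens)  # [[42, 43], [44, 45, 46]]
```

Decoding only the trimmed IDs keeps the prompt text out of the printed answer.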
Video input works the same way as image input: set the content `type` to `"video"` and pass a video URL.
The model then generates text based on the content of the video.
```python
import torch
from transformers import AutoProcessor, Glm4vForConditionalGeneration

processor = AutoProcessor.from_pretrained("THUDM/GLM-4.1V-9B-Thinking")
model = Glm4vForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "url": "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
            },
            {
                "type": "text",
                "text": "Describe this video.",
            },
        ],
    }
]
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
    padding=True,
).to("cuda:0")
generated_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=1.0)
# Decode only the tokens generated after the prompt.
output_text = processor.decode(
    generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(output_text)
```
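Since the image and video examples differ only in the content `type` and URL, the message structure can be built by one small function. The sketch below assumes a hypothetical helper, `build_user_message`, which is not part of Transformers; it only assembles the dictionary layout used in the examples above:

```python
def build_user_message(media_type, url, prompt):
    """Return a single-turn chat with one media item and one text prompt.

    Hypothetical helper for illustration; not a Transformers API.
    """
    if media_type not in ("image", "video"):
        raise ValueError(f"media_type must be 'image' or 'video', got {media_type!r}")
    return [
        {
            "role": "user",
            "content": [
                {"type": media_type, "url": url},
                {"type": "text", "text": prompt},
            ],
        }
    ]


image_messages = build_user_message(
    "image",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    "Describe this image.",
)
video_messages = build_user_message(
    "video",
    "https://test-videos.co.uk/vids/bigbuckbunny/mp4/h264/720/Big_Buck_Bunny_720_10s_10MB.mp4",
    "Describe this video.",
)
```

Either list can then be passed to `processor.apply_chat_template` as in the examples above.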
## Glm4vConfig
[[autodoc]] Glm4vConfig
## Glm4vTextConfig
[[autodoc]] Glm4vTextConfig
## Glm4vImageProcessor
[[autodoc]] Glm4vImageProcessor
- preprocess
## Glm4vVideoProcessor
[[autodoc]] Glm4vVideoProcessor
- preprocess
## Glm4vImageProcessorFast
[[autodoc]] Glm4vImageProcessorFast
- preprocess
## Glm4vProcessor
[[autodoc]] Glm4vProcessor
## Glm4vTextModel
[[autodoc]] Glm4vTextModel
- forward
## Glm4vModel
[[autodoc]] Glm4vModel
- forward
## Glm4vForConditionalGeneration
[[autodoc]] Glm4vForConditionalGeneration
- forward