Fix torch.onnx.export of Qwen2-VL vision encoder (#34852)
* Fix torch.onnx.export of Qwen2-VL vision encoder
This PR fixes `torch.onnx.export` support for the vision encoder of Qwen2-VL. The current code casts `cu_seqlens` to `torch.int32`, which leads to errors later on when those values are used for slicing: the exported `Slice` nodes end up mixing `int32` and `int64` inputs, which ONNX shape inference rejects.
c57eafdaa1/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (L1044-L1046)
## Error:
```
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] (op_type:Slice, node name: /blocks.0/attn/Slice_4): axes has inconsistent type tensor(int64)
```
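For context, the failure comes from a dtype mismatch: `cu_seqlens` is computed with `dtype=torch.int32`, while the other tensors feeding the exported `Slice` nodes are `int64`. Below is a minimal sketch (not part of the PR; the example `grid_thw` value is made up) showing the dtype of `cu_seqlens` before and after the change:

```py
# Minimal sketch, not from the PR: shows the dtype of cu_seqlens before and
# after the fix. The example grid_thw value (1 image, 4x4 patch grid) is made up.
import torch

grid_thw = torch.tensor([[1, 4, 4]])  # (t, h, w) per image, dtype torch.int64

# Before the fix: cu_seqlens is forced to int32.
cu_seqlens_old = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
print(cu_seqlens_old.dtype)  # torch.int32

# After the fix: cu_seqlens keeps grid_thw's dtype (int64), so the values
# traced into the exported Slice nodes match the int64 indices used elsewhere.
cu_seqlens_new = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=grid_thw.dtype)
print(cu_seqlens_new.dtype)  # torch.int64
```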
## Code to reproduce issue:
```py
import requests
from PIL import Image

import torch
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
)

# Constants
VISION_MODEL_NAME = "vision_encoder.onnx"

# Load model and processor
model_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)

# Prepare inputs
url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
image = Image.open(requests.get(url, stream=True).raw)
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    },
]
images = [image]
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(text=[text_prompt], images=images, padding=True, return_tensors="pt")

## Vision model
vision_inputs = dict(
    pixel_values=inputs["pixel_values"],
    grid_thw=inputs["image_grid_thw"],
)
vision_inputs_positional = tuple(vision_inputs.values())
vision_outputs = model.visual.forward(*vision_inputs_positional)  # Test forward pass

torch.onnx.export(
    model.visual,
    args=vision_inputs_positional,
    f=VISION_MODEL_NAME,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=list(vision_inputs.keys()),
    output_names=["image_features"],
    dynamic_axes={
        "pixel_values": {
            0: "batch_size * grid_t * grid_h * grid_w",
            1: "channel * temporal_patch_size * patch_size * patch_size",
        },
        "grid_thw": {0: "batch_size"},
        "image_features": {0: "batch_size * grid_t * grid_h * grid_w"},
    },
)

# Load and check the exported model
import onnx

model = onnx.load(VISION_MODEL_NAME)
onnx.checker.check_model(model, full_check=True)
inferred = onnx.shape_inference.infer_shapes(model, check_type=True)
```
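As an optional extra check (not part of the original report), the exported graph can be run with `onnxruntime` and compared against the eager PyTorch output; this sketch assumes `onnxruntime` is installed and reuses `inputs`, `vision_outputs`, and `VISION_MODEL_NAME` from the script above:

```py
# Optional sanity check (not from the PR): run the exported vision encoder with
# onnxruntime and compare against the eager PyTorch output computed above.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(VISION_MODEL_NAME, providers=["CPUExecutionProvider"])
ort_outputs = session.run(
    None,
    {
        "pixel_values": inputs["pixel_values"].numpy(),
        "grid_thw": inputs["image_grid_thw"].numpy(),
    },
)
np.testing.assert_allclose(
    vision_outputs.detach().numpy(), ort_outputs[0], atol=1e-4
)
```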
* Formatting
* [run-slow] qwen2_vl
This commit is contained in: parent d5cf91b346 · commit 1f6b423f0c
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

```diff
@@ -1025,7 +1025,7 @@ class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
         rotary_pos_emb = self.rot_pos_emb(grid_thw)

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
-            dim=0, dtype=torch.int32
+            dim=0, dtype=grid_thw.dtype
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
```