
Commit a259d62

mm: add qwen vl2.5 model support. (#86)
- Add Qwen VL 2.5 model support.
- Qwen VL 2.5 only supports 'transformers' as the ViT engine (TRT is not supported yet).
- Upgrade package versions to make sure the VL 2.5 code is included.

Test commands:

Server:

```
dashinfer_vlm_serve --model qwen/Qwen2.5-VL-3B-Instruct --vision_engine transformers --port 8000 --host=127.0.0.1
```

Client:

```
curl http://localhost:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d \
  '{"model": "qwen/Qwen2.5-VL-3B-Instruct", "messages": [{"role": "user", "content": [{ "type": "text", "text": "Describe the image." }, {"type": "image_url", "image_url": {"url": "https://farm4.staticflickr.com/3075/3168662394_7d7103de7d_z_d.jpg"}}]}], "max_completion_tokens": 1024, "top_p": 0.5, "temperature": 0.1, "frequency_penalty": 1.05 }'
```

Result:

```
{"id":"chatcmpl-rxqDiCQEJweEeeB7FADiER","object":"chat.completion", "created":1747992522,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":"The image features a small hummingbird perched on a branch. The bird is positioned in the center of the scene, with its vibrant colors and delicate features clearly visible. The hummingbird appears to be enjoying its time in nature, possibly searching for food or simply resting on the branch. \n\nThere are no other birds or animals present in the image, making it a solitary moment captured in this natural setting."},"finish_reason":"stop"}],"usage":{"prompt_tokens":382,"total_tokens":95,"completion_tokens":81}}
```
1 parent 4b57232 commit a259d62
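
For reference, the same request can also be sent from Python through the OpenAI-compatible endpoint. This is a minimal sketch, not part of the commit: it assumes a recent `openai` client package, the server started with the command above on `localhost:8000`, and a placeholder API key (the local server does not validate it).

```python
# Minimal sketch: send the same multimodal chat request via the OpenAI-compatible API.
# Assumes dashinfer_vlm_serve is running on localhost:8000 (see the server command above);
# the api_key value is a placeholder.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="qwen/Qwen2.5-VL-3B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe the image."},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://farm4.staticflickr.com/3075/3168662394_7d7103de7d_z_d.jpg"
                    },
                },
            ],
        }
    ],
    max_completion_tokens=1024,
    top_p=0.5,
    temperature=0.1,
    frequency_penalty=1.05,
)
print(response.choices[0].message.content)
```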

File tree

2 files changed: +79 -33 lines changed

multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py

Lines changed: 74 additions & 28 deletions
```diff
@@ -6,10 +6,14 @@
 import torch
 import glob
 import warnings
-from modelscope import snapshot_download
-from transformers import Qwen2VLForConditionalGeneration, AutoConfig, AutoTokenizer
-from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 from tqdm import tqdm
+
+from transformers import AutoConfig, AutoTokenizer, AutoProcessor
+
+from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
+
 from safetensors.torch import safe_open
 from dashinfer import allspark
 from dashinfer.allspark.model_loader import HuggingFaceModel, ModelSerializerException
@@ -59,25 +63,58 @@ def load_model(
         # the open-source model can be loaded by huggingface
         try:
             if not os.path.isdir(self.hf_model_path):
+                from modelscope import snapshot_download
                 self.hf_model_path = snapshot_download(self.hf_model_path)
-            self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=self.trust_remote_code,
-                torch_dtype=dtype_to_torch_dtype(self.data_type),
-                device_map="cpu",
-                **kwargs,
-            ).eval()
-            self.vit_config = Qwen2VLVisionConfig.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=True,
-                revision=None,
-                code_revision=None,
-            )
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                self.hf_model_path,
-                trust_remote_code=self.trust_remote_code,
-                **kwargs,
+
+            # Read config to determine model architecture
+            self.hf_model_config = AutoConfig.from_pretrained(
+                self.hf_model_path, trust_remote_code=self.trust_remote_code
             )
+
+            if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
+                self.torch_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="cpu",
+                    **kwargs,
+                ).eval()
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.processor = AutoProcessor.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.vit_config = Qwen2_5_VLVisionConfig.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=True,
+                    revision=None,
+                    code_revision=None,
+                )
+            else:
+                self.torch_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="cpu",
+                    **kwargs,
+                ).eval()
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    **kwargs,
+                )
+                self.vit_config = Qwen2VLVisionConfig.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=True,
+                    revision=None,
+                    code_revision=None,
+                )
+            pass
         except Exception as e:
             print(
                 f"exception when load model: {self.hf_model_path} , exception: {e}"
@@ -102,10 +139,10 @@ def read_model_config(self):
         self.hf_model_config = AutoConfig.from_pretrained(
             self.hf_model_path, trust_remote_code=self.trust_remote_code
         )
-        self.adapter = QWen2ConfigAdapter(self.hf_model_config)
-        self.as_model_config = self.adapter.model_config
-        if self.user_set_data_type is None:
-            self.data_type = self.adapter.get_model_data_type()
+        self.adapter = QWen2ConfigAdapter(self.hf_model_config)
+        self.as_model_config = self.adapter.model_config
+        if self.user_set_data_type is None:
+            self.data_type = self.adapter.get_model_data_type()
         return self

     def serialize(
@@ -127,17 +164,26 @@ def serialize(
             onnx_trt_obj.export_onnx(onnxFile)
             onnx_trt_obj.generate_trt_engine(onnxFile, self.vision_model_path)
         elif self.vision_engine == "transformers":
-            visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
+            if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
+                visual_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                 self.hf_model_path,
                 trust_remote_code=self.trust_remote_code,
                 torch_dtype=dtype_to_torch_dtype(self.data_type),
-                device_map="cpu",
-                attn_implementation="flash_attention_2",
+                device_map="auto",
+                attn_implementation="sdpa",
+                ).visual.eval()
+            else:
+                visual_model = Qwen2VLForConditionalGeneration.from_pretrained(
+                    self.hf_model_path,
+                    trust_remote_code=self.trust_remote_code,
+                    torch_dtype=dtype_to_torch_dtype(self.data_type),
+                    device_map="auto",
+                    attn_implementation="sdpa",
             ).visual.eval()
             self.vision_model_path = visual_model
         else:
             raise ValueError(f"unsupported engine {self.vision_engine}")
-
+
         # Convert Allspark LLM
         enable_quant = False
         weight_only_quant=False
```
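
The core of the model_loader.py change is dispatching on the `architectures` field of the checkpoint's config: if it contains `Qwen2_5_VLForConditionalGeneration`, the Qwen 2.5 VL classes (plus an `AutoProcessor` and `Qwen2_5_VLVisionConfig`) are used, otherwise the existing Qwen 2 VL path is kept. A standalone sketch of that check follows; the checkpoint id is illustrative, and only the class names are taken from the diff above.

```python
# Standalone sketch of the architecture dispatch used in load_model():
# read config.json and pick the Qwen2-VL or Qwen2.5-VL class accordingly.
from transformers import (
    AutoConfig,
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

model_path = "Qwen/Qwen2.5-VL-3B-Instruct"  # illustrative checkpoint id
config = AutoConfig.from_pretrained(model_path)

architectures = getattr(config, "architectures", None) or []
if "Qwen2_5_VLForConditionalGeneration" in architectures:
    model_cls = Qwen2_5_VLForConditionalGeneration  # Qwen 2.5 VL checkpoints
else:
    model_cls = Qwen2VLForConditionalGeneration  # Qwen 2 VL checkpoints

model = model_cls.from_pretrained(model_path, device_map="cpu").eval()
print(type(model).__name__)
```

Note that only the 'transformers' vision engine gains the Qwen 2.5 VL branch; the TRT export path is unchanged, which matches the "TRT not supported yet" note in the commit message.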

multimodal/requirements.txt

Lines changed: 5 additions & 5 deletions
```diff
@@ -1,9 +1,9 @@
 dashinfer@https://github.com/modelscope/dash-infer/releases/download/v2.0.0-rc3/dashinfer-2.0.0rc3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
 av
-numpy==1.24.3
-requests==2.32.3
-nvtx==0.2.10
-transformers>=4.45.0
+numpy>=1.24.3
+requests>=2.32.3
+nvtx>=0.2.10
+transformers>=4.48.9
 cachetools>=5.4.0
 six
 tiktoken
@@ -12,7 +12,7 @@ shortuuid
 fastapi
 pydantic_settings
 uvicorn
-cmake==3.22.6
+cmake>=3.22.6
 modelscope
 aiohttp
 onnx
```
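
The transformers floor moves from 4.45.0 to 4.48.9 so that the Qwen 2.5 VL classes are importable. A quick sanity check of an installed environment is sketched below; the exact minimum version that ships `Qwen2_5_VLForConditionalGeneration` is not pinned here, so the import test is the authoritative check.

```python
# Sanity check: confirm the installed transformers build exposes the Qwen2.5-VL class.
import transformers

print("transformers", transformers.__version__)
try:
    from transformers import Qwen2_5_VLForConditionalGeneration  # noqa: F401
except ImportError:
    raise SystemExit("transformers is too old for Qwen2.5-VL; reinstall from requirements.txt")
print("Qwen2.5-VL support available")
```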
