
Commit f3cca8e

x574chen (Xiaotong Chen) authored and committed

dashinfer vlm: add tensorrt support for qwen2.5vl (#91)

* dashinfer vlm: add tensorrt support for qwen2.5vl
* update benchmark

---------

Co-authored-by: Xiaotong Chen <[email protected]>

1 parent 29326d7 commit f3cca8e
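The headline change adds a TensorRT path for the Qwen2.5-VL vision encoder. As an illustrative sketch only (this code is not part of the commit; the ONNX file name and the fp16 setting are assumptions), building a TensorRT engine from an exported ViT graph typically looks like this:

```python
# Illustrative sketch, not code from this commit: build a serialized
# TensorRT engine from an ONNX export of a vision transformer.
# "vit.onnx" is a placeholder path, and fp16 is an assumed precision.
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def build_vit_engine(onnx_path="vit.onnx"):
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    parser = trt.OnnxParser(network, TRT_LOGGER)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            raise RuntimeError(f"ONNX parse failed: {parser.get_error(0)}")
    config = builder.create_builder_config()
    config.set_flag(trt.BuilderFlag.FP16)  # half precision for the ViT
    # Returns a serialized engine (trt.IHostMemory) ready to write to disk.
    return builder.build_serialized_network(network, config)
```

The serialized engine can then be deserialized with trt.Runtime and run in place of the transformers ViT forward pass.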

File tree

13 files changed: +573 -101 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -26,3 +26,4 @@ third_party/from_source/openssl/*
 log*
 *.csv
 *.as*
+*.egg-info

multimodal/README.md

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ DashInfer VLMs is a toolkit to support Vision Language Models (VLMs) inference b
 
 ## Supported Models
 - Qwen2-VL 2B/7B/72B
-- Qwen2.5-VL 2B/7B/72B (Only support transformers vit engine)
+- Qwen2.5-VL 3B/7B/32B/72B
 
 ## Architecture
 ![alt text](resource/dashinfer-vlm-arch.png)

multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py

Lines changed: 1 addition & 39 deletions
@@ -21,32 +21,11 @@
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     Qwen2VisionTransformerPretrainedModel,
 )
+from .utils import default_weight_loader
 
-# from .model_loader import default_weight_loader
 dtype = "fp32"
 
 
-def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
-    """Default weight loader."""
-    try:
-        if param.numel() == 1 and loaded_weight.numel() == 1:
-            # Sometimes scalar values aren't considered tensors with shapes
-            # so if both param and loaded_weight are a scalar,
-            # "broadcast" instead of copy
-            param.data.fill_(loaded_weight.item())
-        else:
-            assert param.size() == loaded_weight.size(), (
-                f"Attempted to load weight ({loaded_weight.size()}) "
-                f"into parameter ({param.size()})"
-            )
-
-            param.data.copy_(loaded_weight)
-    except Exception:
-        # NOTE: This exception is added for the purpose of setting breakpoint to
-        # debug weight loading issues.
-        raise
-
-
 def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
     return x * torch.sigmoid(1.702 * x)
 
@@ -400,25 +379,8 @@ def forward(self, x, cu_seqlens, rotary_pos_emb, b) -> torch.Tensor:
     x = x + self.mlp(self.norm2(x))
     return x
 
-
-# class Qwen2VisionTransformer(nn.Module):
 class Qwen2VisionTransformer(Qwen2VisionTransformerPretrainedModel):
     def __init__(self, config):
-        # img_size: int = 378,
-        # patch_size: int = 14,
-        # temporal_patch_size: int = 2,
-        # spatial_merge_size: int = 2,
-        # in_chans: int = 3,
-        # hidden_size: int = 1000,
-        # embed_dim: int = 768,
-        # depth: int = 12,
-        # num_heads: int = 16,
-        # mlp_ratio: float = 4.0,
-        # norm_layer: nn.Module = partial(LayerNorm, eps=1e-6),
-        # use_flash_attention: bool = False,
-        # *args,
-        # **kwargs,
-        # ) -> None:
         super().__init__(config)
         self.spatial_merge_size = config.spatial_merge_size
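The deleted helper now lives in .utils and is imported from there. For context, such a loader is driven one parameter at a time when restoring a checkpoint. A minimal usage sketch (the `model` and `checkpoint` names below are placeholders, not identifiers from this diff):

```python
# Sketch only: driving a default_weight_loader-style helper over a
# checkpoint. `model` and `checkpoint` are illustrative placeholders;
# the import path mirrors the relative import added in this diff.
from dashinfer_vlm.visual_embedding.utils import default_weight_loader
from torch import nn

def load_weights(model: nn.Module, checkpoint: dict) -> None:
    params = dict(model.named_parameters())
    for name, loaded_weight in checkpoint.items():
        if name in params:
            # Copies (or scalar-broadcasts) each tensor into its parameter;
            # the helper asserts that the shapes match.
            default_weight_loader(params[name], loaded_weight)
```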
