Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -26,3 +26,4 @@ third_party/from_source/openssl/*
log*
*.csv
*.as*
*.egg-info
2 changes: 1 addition & 1 deletion multimodal/README.md
@@ -8,7 +8,7 @@ DashInfer VLMs is a toolkit to support Vision Language Models (VLMs) inference b

## Supported Models
- Qwen2-VL 2B/7B/72B
- Qwen2.5-VL 2B/7B/72B (Only support transformers vit engine)
- Qwen2.5-VL 3B/7B/32B/72B

## Architecture
![alt text](resource/dashinfer-vlm-arch.png)
40 changes: 1 addition & 39 deletions multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py
@@ -21,32 +21,11 @@
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
)
from .utils import default_weight_loader

# from .model_loader import default_weight_loader
dtype = "fp32"


def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})"
            )

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise


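For context, a minimal usage sketch of this loader (illustrative only, not part of the change); the `nn.Linear` is a hypothetical stand-in for the real vision tower, and `default_weight_loader` is the function defined/imported above:

```python
# Illustrative sketch: copy tensors from a checkpoint-style state_dict into
# matching model parameters using default_weight_loader (defined/imported above).
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # hypothetical stand-in for the real vision tower
checkpoint = {k: torch.randn_like(v) for k, v in model.state_dict().items()}

for name, param in model.named_parameters():
    if name in checkpoint:
        default_weight_loader(param, checkpoint[name])
```
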
def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)

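As a side note for reviewers, `quick_gelu` is the sigmoid approximation of GELU used by CLIP-style vision towers; a purely illustrative check against PyTorch's exact GELU:

```python
# Illustrative check: quick_gelu approximates exact GELU to within roughly 1e-2.
import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=101)
approx = x * torch.sigmoid(1.702 * x)  # same formula as quick_gelu above
exact = F.gelu(x)
print((approx - exact).abs().max())    # small approximation error
```
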
@@ -400,25 +379,8 @@ def forward(self, x, cu_seqlens, rotary_pos_emb, b) -> torch.Tensor:
        x = x + self.mlp(self.norm2(x))
        return x

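The block above follows the standard pre-norm residual pattern, `x = x + sublayer(norm(x))`. A self-contained sketch of that pattern with made-up dimensions (not the exact block in DFN_vit.py):

```python
# Minimal pre-norm residual MLP block; dimensions are illustrative only.
import torch
import torch.nn as nn

class PreNormMLPBlock(nn.Module):
    def __init__(self, dim: int = 768, mlp_ratio: float = 4.0) -> None:
        super().__init__()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm2(x))  # residual around the normalized MLP

x = torch.randn(16, 768)
print(PreNormMLPBlock()(x).shape)  # torch.Size([16, 768])
```
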

# class Qwen2VisionTransformer(nn.Module):
class Qwen2VisionTransformer(Qwen2VisionTransformerPretrainedModel):
    def __init__(self, config):
        # img_size: int = 378,
        # patch_size: int = 14,
        # temporal_patch_size: int = 2,
        # spatial_merge_size: int = 2,
        # in_chans: int = 3,
        # hidden_size: int = 1000,
        # embed_dim: int = 768,
        # depth: int = 12,
        # num_heads: int = 16,
        # mlp_ratio: float = 4.0,
        # norm_layer: nn.Module = partial(LayerNorm, eps=1e-6),
        # use_flash_attention: bool = False,
        # *args,
        # **kwargs,
        # ) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size

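For readers unfamiliar with the Hugging Face base class: the vision tower is built from the Qwen2-VL config's `vision_config`, and `spatial_merge_size` (2 by default for Qwen2-VL) controls how many adjacent patches are merged before projection into the LLM. A hedged sketch; the model id and loading path are assumptions, not taken from this PR:

```python
# Illustrative only: construct the vision tower from a Qwen2-VL config.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed model id
vit = Qwen2VisionTransformer(cfg.vision_config)                # class defined above
print(vit.spatial_merge_size)                                  # 2 for Qwen2-VL
```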