diff --git a/.gitignore b/.gitignore
index 5753fddae..ac6826c87 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,4 @@ third_party/from_source/openssl/*
 log*
 *.csv
 *.as*
+*.egg-info
diff --git a/multimodal/README.md b/multimodal/README.md
index a91b07a2d..4eb702760 100644
--- a/multimodal/README.md
+++ b/multimodal/README.md
@@ -8,7 +8,7 @@ DashInfer VLMs is a toolkit to support Vision Language Models (VLMs) inference b
 ## Supported Models
 - Qwen2-VL 2B/7B/72B
-- Qwen2.5-VL 2B/7B/72B (Only support transformers vit engine)
+- Qwen2.5-VL 3B/7B/32B/72B
 
 ## Architecture
 ![alt text](resource/dashinfer-vlm-arch.png)
diff --git a/multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py b/multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py
index 71cd20f32..5deb7535c 100644
--- a/multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py
+++ b/multimodal/dashinfer_vlm/visual_embedding/DFN_vit.py
@@ -21,32 +21,11 @@
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     Qwen2VisionTransformerPretrainedModel,
 )
+from .utils import default_weight_loader
-# from .model_loader import default_weight_loader
 
 dtype = "fp32"
 
 
-def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
-    """Default weight loader."""
-    try:
-        if param.numel() == 1 and loaded_weight.numel() == 1:
-            # Sometimes scalar values aren't considered tensors with shapes
-            # so if both param and loaded_weight are a scalar,
-            # "broadcast" instead of copy
-            param.data.fill_(loaded_weight.item())
-        else:
-            assert param.size() == loaded_weight.size(), (
-                f"Attempted to load weight ({loaded_weight.size()}) "
-                f"into parameter ({param.size()})"
-            )
-
-            param.data.copy_(loaded_weight)
-    except Exception:
-        # NOTE: This exception is added for the purpose of setting breakpoint to
-        # debug weight loading issues.
-        raise
-
-
 def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
     return x * torch.sigmoid(1.702 * x)
@@ -400,25 +379,8 @@ def forward(self, x, cu_seqlens, rotary_pos_emb, b) -> torch.Tensor:
         x = x + self.mlp(self.norm2(x))
         return x
 
-
-# class Qwen2VisionTransformer(nn.Module):
 class Qwen2VisionTransformer(Qwen2VisionTransformerPretrainedModel):
     def __init__(self, config):
-        # img_size: int = 378,
-        # patch_size: int = 14,
-        # temporal_patch_size: int = 2,
-        # spatial_merge_size: int = 2,
-        # in_chans: int = 3,
-        # hidden_size: int = 1000,
-        # embed_dim: int = 768,
-        # depth: int = 12,
-        # num_heads: int = 16,
-        # mlp_ratio: float = 4.0,
-        # norm_layer: nn.Module = partial(LayerNorm, eps=1e-6),
-        # use_flash_attention: bool = False,
-        # *args,
-        # **kwargs,
-        # ) -> None:
         super().__init__(config)
 
         self.spatial_merge_size = config.spatial_merge_size
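# Illustrative usage sketch for the relocated loader (assumes the dashinfer_vlm package is
# importable): the hunk above only swaps DFN_vit.py's local copy of default_weight_loader
# for the shared one in visual_embedding/utils.py, added later in this diff. The layer and
# checkpoint tensor below are made up.
import torch
import torch.nn as nn
from dashinfer_vlm.visual_embedding.utils import default_weight_loader

layer = nn.Linear(4, 4, bias=False)
checkpoint_tensor = torch.randn(4, 4)          # stand-in for a tensor read from a checkpoint shard
default_weight_loader(layer.weight, checkpoint_tensor)
assert torch.equal(layer.weight.data, checkpoint_tensor)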
diff --git a/multimodal/dashinfer_vlm/visual_embedding/DFN_vit_2_5.py b/multimodal/dashinfer_vlm/visual_embedding/DFN_vit_2_5.py
new file mode 100644
index 000000000..84796dfc5
--- /dev/null
+++ b/multimodal/dashinfer_vlm/visual_embedding/DFN_vit_2_5.py
@@ -0,0 +1,497 @@
+# Copyright (c) Alibaba Cloud.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from functools import partial
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import LayerNorm
+import torch.utils.checkpoint
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionTransformerPretrainedModel
+)
+from .utils import default_weight_loader
+
+
+def quick_gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+    return x * torch.sigmoid(1.702 * x)
+
+
+class QuickGELU(nn.Module):
+    """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)"""
+
+    def __init__(self, inplace: bool = False) -> None:
+        super(QuickGELU, self).__init__()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return quick_gelu(input)
+
+
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2RotaryEmbedding
+class Qwen2RotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
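# Quick self-contained check of the rotary helpers above: applying (cos, sin) with
# rotate_half is a per-pair 2-D rotation, so vector norms are preserved. Shapes are toy values.
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim = 8
freqs = torch.rand(head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)        # same duplication as the cos/sin caches above
cos, sin = emb.cos(), emb.sin()
q = torch.randn(3, head_dim)
q_rot = q * cos + rotate_half(q) * sin
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))   # True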
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        mrope_section(`List(int)`):
+            Multimodal Sections for t,h,w in Multimodal inputs
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    if mrope_section:
+        cos = cos[position_ids]
+        sin = sin[position_ids]
+        mrope_section = mrope_section * 2
+        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+            unsqueeze_dim
+        )
+        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
+            unsqueeze_dim
+        )
+    else:
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    cos = freqs.cos()
+    sin = freqs.sin()
+    # rotary_2 interleaved start
+    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
+    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+    # rotary_2 interleaved end
+    output = output.type_as(tensor)
+    return output
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._freqs_cached = None
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seqlen *= 2
+        self.inv_freq = 1.0 / (
+            self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) / self.dim)
+        )
+        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = seq.unsqueeze(1) * self.inv_freq.unsqueeze(0)
+        # freqs = torch.outer(seq, self.inv_freq)
+        return freqs[:seqlen]
+
+
+class PatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_chans: int = 3,
+        hidden_size: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.hidden_size = hidden_size
+
+        kernel_size = [temporal_patch_size, patch_size, patch_size]
+        self.proj = nn.Conv3d(in_chans, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        seqlen = x.shape[0]
+        x = x.view(seqlen, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
+        x = self.proj(x).view(seqlen, self.hidden_size)
+        return x
+
+
+class PatchMerger(nn.Module):
+    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
+        super().__init__()
+        self.out_hidden_size = context_dim * (spatial_merge_size**2)
+        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.mlp = nn.Sequential(
+            nn.Linear(self.out_hidden_size, self.out_hidden_size),
+            nn.GELU(),
+            nn.Linear(self.out_hidden_size, dim),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.mlp(self.ln_q(x).view(-1, self.out_hidden_size))
+        return x
+
+
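# Shape walk-through for PatchEmbed / PatchMerger above (illustrative sizes, not the real
# Qwen2.5-VL config values): every temporal_patch_size x patch_size x patch_size patch
# becomes one hidden vector, and the merger folds each spatial_merge_size^2 group of
# neighbouring patch vectors into a single token for the LLM.
patch_size, temporal_patch_size, in_chans = 14, 2, 3
hidden_size, spatial_merge_size = 1152, 2

t, h, w = 1, 8, 8                              # an 8x8 grid of patches from one frame pair
num_patches = t * h * w
patch_dim = in_chans * temporal_patch_size * patch_size * patch_size
tokens_for_llm = num_patches // spatial_merge_size**2
print(patch_dim, num_patches, tokens_for_llm)  # 1176 64 16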
+class Qwen2RMSNorm(nn.Module):
+    def __init__(self, out_hidden_size, eps=1e-6):
+        """
+        Qwen2RMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(out_hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class VisionMlp(nn.Module):
+    def __init__(self, dim: int, hidden_dim: int):
+        super().__init__()
+        self.out_hidden_size = dim
+        self.intermediate_size = hidden_dim
+        self.gate_proj = nn.Linear(self.out_hidden_size, self.intermediate_size)
+        self.up_proj = nn.Linear(self.out_hidden_size, self.intermediate_size)
+        self.down_proj = nn.Linear(self.intermediate_size, self.out_hidden_size)
+        self.act_fn = QuickGELU()
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16, use_flash_attention: bool = False) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.use_flash_attention = use_flash_attention
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(
+        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None, b=1
+    ) -> torch.Tensor:
+        # return self.flash_forward(x, cu_seqlens, rotary_pos_emb)
+
+        n = self.num_heads
+        d = self.head_dim
+
+        N, _ = x.shape
+
+        qkv = self.qkv(x)
+        qkv = qkv.reshape(N, 3, self.num_heads, -1)
+        q, k, v = qkv.split(1, dim=1)
+
+        q = q.view(1, -1, n, d)
+        k = k.view(1, -1, n, d)
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+        q = q.view(b, -1, n, d)
+        k = k.view(b, -1, n, d)
+        v = v.view(b, -1, n, d)
+
+        softmax_scale = math.pow(d, -0.25)
+        b = v.size(0)
+        q = q.view(b, -1, n, d)
+        k = k.view(b, -1, n, d)
+        v = v.view(b, -1, n, d)
+
+        q = q.permute(0, 2, 1, 3) * softmax_scale
+        k = k.permute(0, 2, 3, 1) * softmax_scale
+        v = v.permute(0, 2, 1, 3)
+
+        attn = torch.matmul(q, k)
+        attn = F.softmax(attn, dim=-1).type_as(attn)
+        x = torch.matmul(attn, v).permute(0, 2, 1, 3)
+        x = x.reshape(b, -1, n * d)
+        x = self.proj(x.contiguous())
+        x = x.view(-1, n * d)
+        return x
+
+    def flash_forward(
+        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+    ) -> torch.Tensor:
+        L, _ = x.shape
+        q, k, v = self.qkv(x).reshape(L, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+        if flash_attn_varlen_func is not None and q.dtype in [torch.float16, torch.bfloat16]:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            x = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
+            x = x.reshape(L, -1)
+        else:
+            attention_mask = torch.zeros([1, L, L], device=q.device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+            q = q.transpose(0, 1)
+            k = k.transpose(0, 1)
+            v = v.transpose(0, 1)
+            x = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0).transpose(0, 1).reshape(L, -1)
+        x = self.proj(x)
+        return x
+
+
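# Illustration of the cu_seqlens convention used by VisionAttention.flash_forward above:
# cumulative lengths [0, n1, n1+n2, ...] describe patches of several images packed into one
# sequence, and the non-flash fallback expands them into a block-diagonal attention mask.
import torch

cu_seqlens = torch.tensor([0, 4, 10])          # two packed images: 4 and 6 patches
L = int(cu_seqlens[-1])
mask = torch.zeros(L, L, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    s, e = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    mask[s:e, s:e] = True
print(mask.int())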
+class Qwen2VLVisionBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        intermediate_size: float,
+        norm_layer: nn.Module = partial(Qwen2RMSNorm, eps=1e-6),
+        use_flash_attention: bool = False,
+    ) -> None:
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+
+        self.attn = VisionAttention(dim, num_heads=num_heads, use_flash_attention=use_flash_attention)
+        self.mlp = VisionMlp(dim=dim, hidden_dim=intermediate_size)
+
+    def forward(self, x, cu_seqlens, rotary_pos_emb, b) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, b=b)
+        x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class Qwen2_5VisionTransformer(Qwen2_5_VisionTransformerPretrainedModel):
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+
+        self.spatial_merge_size = config.spatial_merge_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_chans = config.in_chans
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_heads
+        self.intermediate_size = config.intermediate_size
+        self.patch_size = config.patch_size
+        self.depth = config.depth
+        self.norm_layer = partial(Qwen2RMSNorm, eps=1e-6)
+        self.out_hidden_size = config.out_hidden_size
+        self.fullatt_block_indexes = config.fullatt_block_indexes
+        self.window_size = config.window_size
+        use_flash_attention = False
+
+        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+        self.patch_embed = PatchEmbed(
+            patch_size=self.patch_size,
+            temporal_patch_size=self.temporal_patch_size,
+            in_chans=self.in_chans,
+            hidden_size=self.hidden_size,
+        )
+
+        head_dim = self.hidden_size // self.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = nn.ModuleList(
+            [
+                Qwen2VLVisionBlock(
+                    dim=self.hidden_size,
+                    num_heads=self.num_heads,
+                    intermediate_size=self.intermediate_size,
+                    norm_layer=self.norm_layer,
+                    use_flash_attention=use_flash_attention,
+                )
+                for _ in range(self.depth)
+            ]
+        )
+        self.merger = PatchMerger(dim=self.out_hidden_size, context_dim=self.hidden_size)
+
+    def get_dtype(self) -> torch.dtype:
+        # VisionMlp in this file has no fc2; probe the down projection instead.
+        return self.blocks[0].mlp.down_proj.weight.dtype
+
+    def get_device(self) -> torch.device:
+        return self.blocks[0].mlp.down_proj.weight.device
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            hpos_ids = (
+                hpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            wpos_ids = (
+                wpos_ids.reshape(
+                    h // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                    w // self.spatial_merge_size,
+                    self.spatial_merge_size,
+                )
+                .permute(0, 2, 1, 3)
+                .flatten()
+            )
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
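# Toy version of the (h, w) position ids built by rot_pos_emb above for a 4x4 patch grid
# with spatial_merge_size=2: ids are emitted merge-window by merge-window, matching the
# grouping PatchMerger applies later.
import torch

h = w = 4
merge = 2
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos = hpos.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()
wpos = wpos.reshape(h // merge, merge, w // merge, merge).permute(0, 2, 1, 3).flatten()
print(torch.stack([hpos, wpos], dim=-1)[:4].tolist())   # [[0, 0], [0, 1], [1, 0], [1, 1]]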
+    def fix_attn_bias(self):
+        for blk in self.blocks:
+            blk.attn.qkv.bias = nn.Parameter(
+                blk.attn.qkv.bias.view(blk.attn.num_heads, 3, -1).transpose(0, 1).reshape(-1)
+            )
+
+    def get_window_index(self, grid_thw):
+        window_index: list = []
+        cu_window_seqlens: list = [0]
+        window_index_id = 0
+        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
+
+        for grid_t, grid_h, grid_w in grid_thw:
+            llm_grid_h, llm_grid_w = grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size
+            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
+            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
+            index_padded = index_padded.reshape(
+                grid_t, num_windows_h, vit_merger_window_size, num_windows_w, vit_merger_window_size
+            )
+            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+                grid_t, num_windows_h * num_windows_w, vit_merger_window_size, vit_merger_window_size
+            )
+            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+            index_padded = index_padded.reshape(-1)
+            index_new = index_padded[index_padded != -100]
+            window_index.append(index_new + window_index_id)
+            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+        window_index = torch.cat(window_index, dim=0)
+
+        return window_index, cu_window_seqlens
+
+    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+        cu_window_seqlens = torch.tensor(
+            cu_window_seqlens,
+            device=hidden_states.device,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+        seq_len, _ = hidden_states.size()
+        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        hidden_states = hidden_states[window_index, :, :]
+        hidden_states = hidden_states.reshape(seq_len, -1)
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+            dim=0,
+            # Select dtype based on the following factors:
+            #  - FA2 requires that cu_seqlens_q must have dtype int32
+            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+            # See https://github.com/huggingface/transformers/pull/34852 for more information
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for layer_num, blk in enumerate(self.blocks):
+            if layer_num in self.fullatt_block_indexes:
+                cu_seqlens_now = cu_seqlens
+            else:
+                cu_seqlens_now = cu_window_seqlens
+
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb, b=batch.size(0)
+            )
+        hidden_states = self.merger(hidden_states)
+        reverse_indices = torch.argsort(window_index)
+        hidden_states = hidden_states[reverse_indices, :]
+
+        return hidden_states
+
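# Illustrative sketch of how the forward pass above alternates attention granularity:
# most blocks attend within windows (cu_window_seqlens), while the block indexes listed in
# config.fullatt_block_indexes use the full per-image cu_seqlens. Values are examples, not
# read from a real config.
fullatt_block_indexes = [7, 15, 23, 31]
depth = 32
schedule = ["full" if i in fullatt_block_indexes else "window" for i in range(depth)]
print(schedule.count("full"), schedule.count("window"))   # 4 28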
+    def load_weights(self, weights):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if not name.startswith("visual."):
+                continue
+            name = name.split("visual.")[1]
+            if "blocks" in name and "attn.proj.bias" in name:
+                continue
+
+            # Note: only used for debug
+            if name not in params_dict.keys():
+                continue
+            default_weight_loader(params_dict[name], loaded_weight)
\ No newline at end of file
diff --git a/multimodal/dashinfer_vlm/visual_embedding/utils.py b/multimodal/dashinfer_vlm/visual_embedding/utils.py
new file mode 100644
index 000000000..f48234823
--- /dev/null
+++ b/multimodal/dashinfer_vlm/visual_embedding/utils.py
@@ -0,0 +1,21 @@
+import torch
+
+def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+    """Default weight loader."""
+    try:
+        if param.numel() == 1 and loaded_weight.numel() == 1:
+            # Sometimes scalar values aren't considered tensors with shapes
+            # so if both param and loaded_weight are a scalar,
+            # "broadcast" instead of copy
+            param.data.fill_(loaded_weight.item())
+        else:
+            assert param.size() == loaded_weight.size(), (
+                f"Attempted to load weight ({loaded_weight.size()}) "
+                f"into parameter ({param.size()})"
+            )
+
+            param.data.copy_(loaded_weight)
+    except Exception:
+        # NOTE: This exception is added for the purpose of setting breakpoint to
+        # debug weight loading issues.
+        raise
\ No newline at end of file
diff --git a/multimodal/dashinfer_vlm/vl_inference/__init__.py b/multimodal/dashinfer_vlm/vl_inference/__init__.py
index d7cabb50a..eaf8d1b54 100644
--- a/multimodal/dashinfer_vlm/vl_inference/__init__.py
+++ b/multimodal/dashinfer_vlm/vl_inference/__init__.py
@@ -2,4 +2,5 @@
 Copyright (c) Alibaba, Inc. and its affiliates.
 @file __init__.py
 '''
-from ..visual_embedding.DFN_vit import Qwen2VisionTransformer
\ No newline at end of file
+from ..visual_embedding.DFN_vit import Qwen2VisionTransformer
+from ..visual_embedding.DFN_vit_2_5 import Qwen2_5VisionTransformer
\ No newline at end of file
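# Illustrative sketch of the (name, tensor) pairs load_weights above consumes: names come
# from the HF checkpoint, and only keys under the "visual." prefix are kept. The entries
# and their shapes below are made up.
import torch

weights = [
    ("visual.patch_embed.proj.weight", torch.zeros(1280, 3, 2, 14, 14)),
    ("model.layers.0.self_attn.q_proj.weight", torch.zeros(8, 8)),  # ignored: no "visual." prefix
]
kept = [name.split("visual.")[1] for name, _ in weights if name.startswith("visual.")]
print(kept)   # ['patch_embed.proj.weight']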
diff --git a/multimodal/dashinfer_vlm/vl_inference/utils/__init__.py b/multimodal/dashinfer_vlm/vl_inference/utils/__init__.py
index 78565e442..8c268eed0 100644
--- a/multimodal/dashinfer_vlm/vl_inference/utils/__init__.py
+++ b/multimodal/dashinfer_vlm/vl_inference/utils/__init__.py
@@ -9,4 +9,4 @@
 from .hie_allspark import *
 from .cache import *
 
-from .. import Qwen2VisionTransformer
\ No newline at end of file
+from .. import Qwen2VisionTransformer, Qwen2_5VisionTransformer
\ No newline at end of file
diff --git a/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py b/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
index f6f6aaa41..bcf471eb8 100644
--- a/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
+++ b/multimodal/dashinfer_vlm/vl_inference/utils/model_loader.py
@@ -160,7 +160,11 @@ def serialize(
             self.vision_model_path = os.path.join(
                 model_output_dir, self.pretain_model_name + ".plan"
             )
-            onnx_trt_obj = ONNX_TRT(self.hf_model_path)
+            if hasattr(self.hf_model_config, "architectures") and "Qwen2_5_VLForConditionalGeneration" in self.hf_model_config.architectures:
+                is_qwen_2_5 = True
+            else:
+                is_qwen_2_5 = False
+            onnx_trt_obj = ONNX_TRT(self.hf_model_path, is_qwen_2_5=is_qwen_2_5)
             onnx_trt_obj.export_onnx(onnxFile)
             onnx_trt_obj.generate_trt_engine(onnxFile, self.vision_model_path)
         elif self.vision_engine == "transformers":
diff --git a/multimodal/dashinfer_vlm/vl_inference/utils/trt/onnx_to_plan.py b/multimodal/dashinfer_vlm/vl_inference/utils/trt/onnx_to_plan.py
index 1425d0f33..d23c53358 100644
--- a/multimodal/dashinfer_vlm/vl_inference/utils/trt/onnx_to_plan.py
+++ b/multimodal/dashinfer_vlm/vl_inference/utils/trt/onnx_to_plan.py
@@ -17,45 +17,52 @@
 from typing import Any, Dict, List, Optional
 import contextlib
 from dataclasses import dataclass
-
+from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig
 import tensorrt as trt
 import torch
 
-from .. import Qwen2VisionTransformer
+from .. import Qwen2VisionTransformer, Qwen2_5VisionTransformer
 
 
 class ONNX_TRT:
-    def __init__(self, model_path=None):
-        from transformers.models.qwen2_vl.configuration_qwen2_vl import (
-            Qwen2VLVisionConfig,
-        )
+    def __init__(self, model_path=None, is_qwen_2_5=False):
+        self.is_qwen_2_5 = is_qwen_2_5
+        if is_qwen_2_5:
+            self.config = Qwen2_5_VLConfig.from_pretrained(
+                model_path, trust_remote_code=True, revision=None, code_revision=None
+            ).vision_config
+            self.model_path = model_path
+            self.input_embed_dim = (
+                self.config.in_channels
+                * self.config.temporal_patch_size
+                * self.config.patch_size
+                * self.config.patch_size
+            )
+        else:
+            self.config = Qwen2VLVisionConfig.from_pretrained(
+                model_path, trust_remote_code=True, revision=None, code_revision=None
+            )
+            self.model_path = model_path
+            self.input_embed_dim = (
+                self.config.in_channels
+                * self.config.temporal_patch_size
+                * self.config.patch_size
+                * self.config.patch_size
+            )
 
-        self.model_path = model_path
-        self.config = Qwen2VLVisionConfig.from_pretrained(
-            model_path, trust_remote_code=True, revision=None, code_revision=None
-        )
-        self.input_embed_dim = (
-            self.config.in_channels
-            * self.config.temporal_patch_size
-            * self.config.patch_size
-            * self.config.patch_size
-        )
 
     def export_onnx(self, onnx_file_path):
         print("Start converting ONNX model!")
-
-        # class SumModule(torch.nn.Module):
-        #     def forward(self, x, y):
-        #         x[0][0][0] = y[0][0][1]
-        #         return torch.sum(x, dim=1)
         model_path = self.model_path
         config = self.config
+        vision_model = Qwen2_5VisionTransformer if self.is_qwen_2_5 else Qwen2VisionTransformer
 
         class WrapModel(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
                 super().__init__(*args, **kwargs)
-                self.vision_model = Qwen2VisionTransformer(config)
+                self.vision_model = vision_model(config)
 
         def get_weights_iterator(model):
             import glob
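# Sketch of the dispatch model_loader.py now performs before exporting the ViT engine:
# Qwen2.5-VL checkpoints list Qwen2_5_VLForConditionalGeneration in config.json, and that
# flag selects the Qwen2.5 vision config and transformer class in ONNX_TRT. The model path
# is a placeholder.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("/path/to/Qwen2.5-VL-7B-Instruct", trust_remote_code=True)
is_qwen_2_5 = "Qwen2_5_VLForConditionalGeneration" in (getattr(cfg, "architectures", None) or [])
print(is_qwen_2_5)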
diff --git a/multimodal/dashinfer_vlm/vl_inference/utils/trt/vit_process.py b/multimodal/dashinfer_vlm/vl_inference/utils/trt/vit_process.py
index 651fe0ee0..f3ad39ce4 100644
--- a/multimodal/dashinfer_vlm/vl_inference/utils/trt/vit_process.py
+++ b/multimodal/dashinfer_vlm/vl_inference/utils/trt/vit_process.py
@@ -5,6 +5,7 @@
 import tensorrt as trt
 import contextlib
 from dataclasses import dataclass
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLVisionConfig
 
 logger = trt.Logger(trt.Logger.WARNING)
 
@@ -300,12 +301,9 @@ def run(
 
 class VisualTRT_V2(HieModel_V2):
     def __init__(self, vit_engine_path, trt_vit_config, input=input):
-        print("loading qwen2-vit by pyhie")
         self.stream = torch.cuda.current_stream().cuda_stream
-        print(f"Loading engine from {vit_engine_path}")
         with open(vit_engine_path, "rb") as f:
             engine_buffer = f.read()
-        print(f"Creating session from engine {vit_engine_path}")
         self.session_vit = Session.from_serialized_engine(engine_buffer)
         self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
         self.trt_vit_config = trt_vit_config
@@ -320,34 +318,19 @@ def forward(self, images, grid_thw, batch, use_flashattn=True):
             "input": images.to(torch.float32),
             "grid_thw": grid_thw.to(torch.int64),
         }
-        # visual_output_info = self.session_vit.infer_shapes(
-        #     [TensorInfo("input", trt.DataType.FLOAT, images.shape), TensorInfo("grid_thw", trt.DataType.INT64, grid_thw.shape)])
-        # visual_outputs = {
-        #     t.name: torch.empty(tuple(t.shape),
-        #                         dtype=trt_dtype_to_torch(t.dtype),
-        #                         device="cuda")
-        #     for t in visual_output_info
-        # }
         self.session_vit.context.set_input_shape("input", images.shape)
         self.session_vit.context.set_input_shape("grid_thw", grid_thw.shape)
-        hidden_size = self.trt_vit_config.hidden_size
-        embed_dim = self.trt_vit_config.embed_dim
+        if isinstance(self.trt_vit_config, Qwen2_5_VLVisionConfig):
+            hidden_size = self.trt_vit_config.out_hidden_size
+        else:
+            hidden_size = self.trt_vit_config.hidden_size
         spatial_merge_size = self.trt_vit_config.spatial_merge_size
-        image_tokens = int(
-            visual_inputs["input"].shape[1]
-            * embed_dim
-            / (embed_dim * (spatial_merge_size**2))
-        )
+        image_tokens = int(visual_inputs["input"].shape[1] * (spatial_merge_size**2))
         visual_outputs = {
            "output": torch.empty(
                 (1, image_tokens, hidden_size), dtype=torch.float32, device="cuda"
             )
         }
-        # profiler.start("ViT")
         ok = self.session_vit.run(visual_inputs, visual_outputs, self.stream)
-        # profiler.stop("ViT")
-        # Vit_time = profiler.elapsed_time_in_sec("ViT")
-        # print(f"TensorRT-LLM ViT latency: {Vit_time:3f} sec ")
         assert ok, "Runtime execution failed for vit session"
-
         return visual_outputs["output"].squeeze(0).clone()
diff --git a/multimodal/tests/benchmark_openai_api.py b/multimodal/tests/benchmark_openai_api.py
index 1c6a3a587..aa5e520c2 100644
--- a/multimodal/tests/benchmark_openai_api.py
+++ b/multimodal/tests/benchmark_openai_api.py
@@ -321,4 +321,4 @@ def print_profiling_data(total_timecost):
     global_end = time.time()
     print(f"Total time: {global_end - global_start_time :.2f} sec")
 
-    print_profiling_data(global_end - global_start_time)
+    print_profiling_data(global_end - global_start_time)
\ No newline at end of file
diff --git a/multimodal/tests/test.jpg b/multimodal/tests/test.jpg
new file mode 100644
index 000000000..ce7c2a67c
Binary files /dev/null and b/multimodal/tests/test.jpg differ
diff --git a/multimodal/tests/test_openai_chat_completion.py b/multimodal/tests/test_openai_chat_completion.py
index 15372c51a..b3b6f22cb 100644
--- a/multimodal/tests/test_openai_chat_completion.py
+++ b/multimodal/tests/test_openai_chat_completion.py
@@ -6,17 +6,18 @@ import argparse
 
 
 def test_text_image_1(client, model):
+    test_image = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.jpg")
     response = client.chat.completions.create(
         model=model,
         messages=[
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": "Describe the image."},
+                    {"type": "text", "text": "Please read and describe the image."},
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": "https://farm4.staticflickr.com/3075/3168662394_7d7103de7d_z_d.jpg",
+                            "url": test_image,
                         },
                     },
                 ],
@@ -31,29 +32,24 @@ def test_text_image_1(client, model):
 
 
 def test_text_multi_images(client, model):
+    test_image = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.jpg")
     response = client.chat.completions.create(
         model=model,
         messages=[
             {
                 "role": "user",
                 "content": [
-                    {"type": "text", "text": "Are these images different?"},
+                    {"type": "text", "text": "Describe the images."},
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": "https://farm4.staticflickr.com/3075/3168662394_7d7103de7d_z_d.jpg",
+                            "url": test_image,
                         },
                     },
                     {
                         "type": "image_url",
                         "image_url": {
-                            "url": "https://farm9.staticflickr.com/8505/8441256181_4e98d8bff5_z_d.jpg",
-                        },
-                    },
-                    {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": "https://farm3.staticflickr.com/2220/1572613671_7311098b76_z_d.jpg",
+                            "url": test_image,
                         },
                     },
                 ],
@@ -127,7 +123,7 @@ def main(args, client):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--host', type=str,
-                        default="localhost")
+                        default="0.0.0.0")
     parser.add_argument('--port', type=str, default="8000")
     parser.add_argument('--type', type=str, default="all",
                         choices=["all", "single_image", "multi_images", "video"])
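# Hedged end-to-end sketch mirroring the updated tests: send a local image to the
# OpenAI-compatible endpoint served by dashinfer-vlm. The base_url layout, api_key and
# model id below are assumptions, not taken from this diff.
import os
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
test_image = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.jpg")
response = client.chat.completions.create(
    model="qwen2.5-vl",   # placeholder model id
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Please read and describe the image."},
            {"type": "image_url", "image_url": {"url": test_image}},
        ],
    }],
)
print(response.choices[0].message.content)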