Commit 21b18d7

Added CB support for InternVL
Signed-off-by: Asmita Goswami <[email protected]>
1 parent 374ddbb commit 21b18d7

File tree: 5 files changed, +262 additions and -28 deletions


QEfficient/generation/embedding_handler.py

Lines changed: 82 additions & 3 deletions
```diff
@@ -12,13 +12,14 @@
 operations, separating them from the main text generation logic.
 """
 
-from typing import Any, Dict, Optional, Tuple
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Tuple
 
 import numpy as np
 import requests
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor
+from transformers import AutoImageProcessor, AutoTokenizer
 
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils.logging_utils import logger
@@ -37,6 +38,9 @@ def __init__(
         qeff_model: Optional[QAICInferenceSession],
         vision_session: Optional[QAICInferenceSession],
         processor: Optional[AutoImageProcessor],
+        tokenizer: Optional[AutoTokenizer],
+        image_height: Optional[int] = None,
+        image_width: Optional[int] = None,
         config: Optional[Dict[str, Any]] = None,
         lang_session: Optional[QAICInferenceSession] = None,
     ):
@@ -46,12 +50,16 @@ def __init__(
         Args:
             vision_session: QAICInferenceSession for vision model
             processor: AutoImageProcessor for image preprocessing
+            tokenizer: AutoTokenizer for text tokenization
            config: Configuration dictionary with vision model parameters
            lang_session: Optional language session for coordination (to avoid resource conflicts)
        """
        self._qeff_model = qeff_model
        self._vision_session = vision_session
        self._processor = processor
+       self._tokenizer = tokenizer
+       self._image_height = image_height
+       self._image_width = image_width
        self._config = config or {}
        self._lang_session = lang_session  # Store language session for coordination
 
@@ -70,6 +78,71 @@ def is_available(self) -> bool:
         """
         return self._vision_session is not None and self._processor is not None
 
+    def prepare_internVL_inputs(self, img_url: str, query: str) -> Dict[str, np.ndarray]:
+        """
+        Prepare inputs for InternVL model
+
+        Args:
+            image_url: URL or path to image
+            query: Text query to process with image
+            prompt = [query]
+        """
+        if not self._tokenizer:
+            raise ValueError("Tokenizer is required for InternVL input preparation")
+        prompt = query
+        pixel_values = []
+        num_patches_list = []
+        questions = []
+        img = requests.get(img_url, stream=True)
+        image = Image.open(BytesIO(img.content)).convert("RGB")
+
+        if self._image_height and self._image_width:
+            image = image.resize((self._image_height, self._image_width))
+        else:
+            logger.warning("Height and Width not specified. Using default image size for num_patches = 13.")
+            image = image.resize((1000, 747))
+
+        # preprocess the resized image
+        pixel_value = self._processor.load_image(image, max_num=12)
+        num_patches_list.append(pixel_value.shape[0])
+        pixel_values.append(pixel_value)
+
+        question = "<image>\n" + prompt
+        questions.append(question)
+
+        pixel_values = torch.cat(pixel_values, dim=0)
+
+        # Chat Template information for prompt preprocessing
+        messages: List[List[str]] = []
+        roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")
+        prompt = self._processor(pixel_values, questions, messages, roles, num_patches_list=num_patches_list)
+
+        inputs = self._tokenizer(prompt, return_tensors="pt")
+        inputs["pixel_values"] = pixel_values.clone()
+
+        # Convert to numpy arrays
+        vision_inputs = {}
+        for k, v in inputs.items():
+            if k in {
+                "pixel_values",
+                "image_masks",
+                "image_input_idx",
+                "valid_idx",
+                "aspect_ratio_ids",
+                "aspect_ratio_mask",
+            }:
+                vision_inputs[k] = np.array(v)
+
+        # Convert specific inputs to float16
+        vision_inputs_fp16 = {"pixel_values", "image_masks"}
+        for k in vision_inputs_fp16:
+            if k in vision_inputs:
+                vision_inputs[k] = vision_inputs[k].astype("float16")
+
+        lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
+
+        return vision_inputs, lang_inputs
+
     def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -> Dict[str, np.ndarray]:
         """
         Download and preprocess image into model inputs
@@ -323,7 +396,13 @@ def get_processed_inputs(
 
         try:
             ## Get vlm inputs ##
-            vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
+            if (
+                hasattr(self._qeff_model.model.config, "model_type")
+                and self._qeff_model.model.config.model_type == "internvl_chat"
+            ):
+                vision_inputs, lang_inputs = self.prepare_internVL_inputs(image_url, query)
+            else:
+                vision_inputs, lang_inputs = self.prepare_vlm_inputs(image_url, query, prefill_seq_len)
 
             # Handle padding for language model
             pad_token_id = 1
```
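For illustration, a minimal wiring sketch of the new InternVL path. The handler class name, the image sizes, and the URL below are assumptions; only the constructor parameters and the `prepare_internVL_inputs` signature come from this diff.

```python
# Hypothetical wiring; in practice vlm_generation.py builds the handler with live
# QAICInferenceSession objects (see the next file in this commit).
handler = VisionEmbeddingHandler(            # class name assumed, not shown in the diff
    qeff_model=qeff_model,
    vision_session=vision_session,
    processor=processor,
    tokenizer=tokenizer,
    image_height=448,                        # illustrative; omit both to fall back to the
    image_width=448,                         # default resize (num_patches = 13)
)

vision_inputs, lang_inputs = handler.prepare_internVL_inputs(
    img_url="https://example.com/demo.jpg",  # placeholder URL
    query="Describe this image.",
)
# vision_inputs: float16 pixel_values (plus any mask tensors) for the vision QPC
# lang_inputs:   tokenized prompt tensors for the language QPC
```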

QEfficient/generation/vlm_generation.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -86,6 +86,8 @@ def __init__(
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         full_batch_size: Optional[int] = None,
+        image_height: Optional[int] = None,
+        image_width: Optional[int] = None,
         is_tlm: bool = False,
         include_sampler: bool = False,
         return_pdfs: bool = False,
@@ -139,6 +141,9 @@ def __init__(
         )
         self.qeff_model = qeff_model
         self.processor = processor
+        self.tokenizer = tokenizer
+        self.image_height = image_height
+        self.image_width = image_width
         self._vision_qpc_path = vision_qpc_path
         self.device_id = device_id  # Store device_id for vision components
         self.enable_debug_logs = enable_debug_logs  # Store for vision components
@@ -169,6 +174,9 @@ def _init_vision_components(self):
             qeff_model=self.qeff_model,
             vision_session=self._vision_session,
             processor=self.processor,
+            tokenizer=self.tokenizer,
+            image_height=self.image_height,
+            image_width=self.image_width,
             config=vision_config,
             lang_session=self._session,  # Pass language session for coordination
         )
```

QEfficient/transformers/models/internvl/modeling_internvl.py

Lines changed: 65 additions & 24 deletions
```diff
@@ -5,6 +5,8 @@
 #
 # -----------------------------------------------------------------------------
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -34,7 +36,15 @@ def __init__(self, model):
         self.config = self.model.language_model.config
         self.language_model = self.model.language_model
 
-    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
+    def forward(
+        self,
+        input_ids,
+        vision_embeds,
+        position_ids,
+        image_idx,
+        past_key_values,
+        batch_index: Optional[torch.LongTensor] = None,
+    ):
         input_embeds = self.model.language_model.get_input_embeddings()(input_ids)
         B, N, C = input_embeds.shape
         image_input_embeds = input_embeds.reshape(B * N, C)
@@ -55,7 +65,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
         inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds)
         inputs_embeds = inputs_embeds.reshape(B, N, C)
         outputs = self.model.language_model(
-            inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            use_cache=True,
         )
         image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
         return outputs.logits, vision_embeds, image_idx, outputs.past_key_values
```
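The new `batch_index` input is what lets a decode step write each active sequence's KV entries into the correct row of a cache sized for `full_batch_size` slots. The snippet below is a toy illustration of that bookkeeping, not the QEfficient/ONNX implementation; all shapes and values are made up.

```python
import torch

# Toy continuous-batching KV update: the cache has one row per CB slot
# (full_batch_size), while a decode step only carries the currently active
# sequences, each tagged with the cache row it owns via batch_index.
full_batch_size, ctx_len, head_dim = 4, 8, 16
past_key = torch.zeros(full_batch_size, 1, ctx_len, head_dim)

batch_index = torch.tensor([[2], [0]])     # two active sequences, owning slots 2 and 0
position_ids = torch.tensor([[5], [3]])    # current write position of each sequence
new_key = torch.randn(2, 1, 1, head_dim)   # one fresh key per active sequence

# Scatter each new key into its own slot and position of the shared cache.
for row in range(batch_index.shape[0]):
    slot = batch_index[row, 0]
    pos = position_ids[row, 0]
    past_key[slot, :, pos] = new_key[row, :, 0]
```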
```diff
@@ -75,6 +89,9 @@ def get_specializations(
         ctx_len: int,
         img_size: int,
         kv_offload: bool = False,
+        continuous_batching: bool = False,
+        kv_cache_batch_size: Optional[int] = None,
+        full_batch_size: Optional[int] = None,
         **compiler_options,
     ):
         num_patches = compiler_options.pop("num_patches", None)
@@ -104,24 +121,38 @@ def get_specializations(
                 "batched_num_patches": batch_size * num_patches,
             }
         ]
-        lang = [
-            {
-                "batch_size": batch_size,
-                "seq_len": prefill_seq_len,
-                "ctx_len": ctx_len,
-                "num_patches": num_patches,
-                "img_size": img_size,
-                "vision_size": vision_size,
-            },
-            {
-                "batch_size": batch_size,
-                "seq_len": "1",
-                "ctx_len": ctx_len,
-                "num_patches": num_patches,
-                "img_size": img_size,
-                "vision_size": vision_size,
-            },
-        ]
+        lang_prefill = {
+            "batch_size": 1 if continuous_batching else batch_size,
+            "seq_len": prefill_seq_len,
+            "ctx_len": ctx_len,
+            "num_patches": num_patches,
+            "img_size": img_size,
+            "vision_size": vision_size,
+        }
+        if continuous_batching:
+            lang_prefill["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_prefill["batch_size"] = kv_cache_batch_size
+        if full_batch_size:
+            lang_prefill["full_batch_exec_size"] = full_batch_size
+
+        lang_decode = {
+            "batch_size": full_batch_size if continuous_batching else batch_size,
+            "seq_len": "1",
+            "ctx_len": ctx_len,
+            "num_patches": num_patches,
+            "img_size": img_size,
+            "vision_size": vision_size,
+        }
+
+        if continuous_batching:
+            lang_decode["full_batch_size"] = kv_cache_batch_size
+        else:
+            lang_decode["batch_size"] = kv_cache_batch_size
+
+        lang = []
+        lang.append(lang_prefill)
+        lang.append(lang_decode)
 
         specializations = {}
```
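For a concrete picture of the prefill/decode asymmetry introduced here, these are roughly the two language specializations produced with continuous_batching=True. All numbers are illustrative placeholders, not values taken from the commit.

```python
# Assumed inputs: kv_cache_batch_size=4, full_batch_size=4, prefill_seq_len=128, ctx_len=4096.
lang_prefill = {
    "batch_size": 1,            # prefill runs one sequence at a time under continuous batching
    "seq_len": 128,
    "ctx_len": 4096,
    "num_patches": 13,
    "img_size": 448,
    "vision_size": 3328,        # placeholder; derived from num_patches in the real code
    "full_batch_size": 4,       # KV-cache rows = kv_cache_batch_size
    "full_batch_exec_size": 4,  # only present when full_batch_size is passed
}
lang_decode = {
    "batch_size": 4,            # decode batches all active slots together
    "seq_len": "1",
    "ctx_len": 4096,
    "num_patches": 13,
    "img_size": 448,
    "vision_size": 3328,
    "full_batch_size": 4,
}
```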

```diff
@@ -130,18 +161,22 @@ def get_specializations(
             specializations["lang"] = lang
             return specializations, compiler_options
         else:
+            lang[0].pop("vision_size")
+            lang[1].pop("vision_size")
             return lang, compiler_options
 
-    def get_onnx_dynamic_axes(self, kv_offload: bool = False):
+    def get_onnx_dynamic_axes(self, kv_offload: bool = False, continuous_batching: bool = False):
         # Define dynamic axes
         vision_dynamic_axes = {}
         lang_dynamic_axes = {}
         lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"}
         lang_dynamic_axes["vision_embeds"] = {1: "vision_size"}
+        if continuous_batching:
+            lang_dynamic_axes["batch_index"] = {0: "batch_size"}
         vision_dynamic_axes["pixel_values"] = {0: "batched_num_patches", 2: "img_size", 3: "img_size"}
 
-        pkv_dynamic_axes = {0: "batch_size", 2: "ctx_len"}
+        pkv_dynamic_axes = {0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}
         for i in range(self.language_model.config.num_hidden_layers):
             for kv in ["key", "value"]:
                 lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
@@ -173,7 +208,7 @@ def get_output_names(self, kv_offload: bool = False):
             return lang_output_names
         return output_names
 
-    def get_dummy_inputs(self, kv_offload: bool = False):
+    def get_dummy_inputs(self, kv_offload: bool = False, continuous_batching: bool = False):
         if vis_cfg := getattr(self.config, "vision_config", None):
             img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE)
         else:
@@ -222,10 +257,13 @@ def get_dummy_inputs(self, kv_offload: bool = False):
         )
         lang_inputs["image_idx"] = torch.zeros((1, 1), dtype=torch.int64)
 
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+        fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
+
         # Add data for KV
         kv_cache_shape = get_padding_shape_from_config(
             config=self.language_model.config,
-            batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
+            batch_size=fbs if continuous_batching else bs,
             seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
         )
 
@@ -234,6 +272,9 @@ def get_dummy_inputs(self, kv_offload: bool = False):
             for kv in ["key", "value"]:
                 lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
 
+        if continuous_batching:
+            lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1)
+
         inputs = {}
         if kv_offload:
             inputs["vision"] = vision_inputs
```

QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -1190,6 +1190,8 @@ def generate(
         device_ids: List[int] = None,
         runtime_ai100: bool = True,
         generation_len: Optional[int] = None,
+        image_height: Optional[int] = None,
+        image_width: Optional[int] = None,
     ) -> Union[torch.Tensor, np.ndarray]:
         """
         Generates output by executing the compiled QPC(s) on Cloud AI 100 Hardware cards.
@@ -1246,6 +1248,8 @@ def generate(
             device_id=device_ids,  # if device_ids is not None else [0],
             ctx_len=ctx_len_comp,
             full_batch_size=fbs,
+            image_height=image_height,
+            image_width=image_width,
         )
 
         # Call generate method
@@ -2273,7 +2277,11 @@ def from_pretrained(
 
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
-                model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+                model,
+                kv_offload=kv_offload,
+                continuous_batching=continuous_batching,
+                pretrained_model_name_or_path=pretrained_model_name_or_path,
+                **kwargs,
             )
         return cls(
             model,
```
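At the API surface, the net effect of the commit is that an InternVL checkpoint can be loaded with continuous batching enabled, and generation can pass explicit image dimensions down to the InternVL input preparation. A hedged sketch follows; the checkpoint name and the commented generate() call are assumptions, while the keyword arguments themselves come from this diff.

```python
from QEfficient import QEFFAutoModelForCausalLM

# InternVL checkpoints register as causal LMs and are re-routed to the
# image-text-to-text wrapper via MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP;
# continuous_batching is now forwarded along that path.
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
    "OpenGVLab/InternVL2_5-1B",   # assumed InternVL checkpoint (model_type == "internvl_chat")
    kv_offload=True,
    continuous_batching=True,
)

# ... export and compile the vision/language QPCs for the target AI 100 device ...

# generate() additionally accepts the image dimensions and threads them down to
# the resize in embedding_handler.py, e.g. (illustrative values):
# qeff_model.generate(..., generation_len=128, image_height=448, image_width=448)
```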

Comments (0)