Commit c80a19a

Updated test_image_text_to_text for CB tests

Signed-off-by: Asmita Goswami <[email protected]>
1 parent 08274bb

2 files changed: +112 -3 lines changed

QEfficient/utils/run_utils.py (48 additions, 0 deletions)
@@ -276,6 +276,54 @@ def __init__(
         self.config = config
         self.gen_len = max_gen_len
 
+    @torch.no_grad()
+    def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries):
+        """
+        Function responsible for running HuggingFace ``PyTorch`` model for continuous batching
+        and returning the output tokens for each prompt/image pair.
+
+        ``Mandatory`` Args:
+            :model (torch.nn.Module): Original ``PyTorch`` model
+            :images (List[PIL.Image]): List of input images
+            :queries (List[str]): List of input queries
+
+        Return:
+            :List[numpy.ndarray]: List of generated output tokens for each prompt
+        """
+        generated_ids = []
+
+        for idx, (image, query) in enumerate(zip(images, queries)):
+            # Prepare conversation format for each image-query pair
+            conversation = [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": query},
+                        {"type": "image"},
+                    ],
+                },
+            ]
+            prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+            # Process inputs
+            inputs = self.processor(images=image, text=prompt, return_tensors="pt")
+            if "pixel_values" in inputs:
+                inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+
+            # Generate tokens
+            output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False)
+            offset_output = output[0, inputs["input_ids"].shape[1]:]
+
+            # Decode and print output
+            py_output = self.processor.tokenizer.decode(offset_output).strip()
+            print(f"Original HF Model Outputs (Torch CPU) for prompt {idx}:")
+            print("Query:", repr(query))
+            print("Completion:", repr(py_output))
+
+            generated_ids.append(offset_output.numpy())
+
+        return generated_ids
+
     @torch.no_grad()
     def run_vlm_hf_model_on_pytorch(self, model, inputs):
         output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False)
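For context, the per-pair flow inside run_vlm_hf_model_on_pytorch_CB is the standard transformers chat-template pipeline. Below is a minimal, self-contained sketch of that flow outside the runner class, using only public transformers APIs; the checkpoint, image URL, prompt, and generation length are illustrative assumptions, not values taken from this commit:

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Illustrative checkpoint and inputs (assumptions; any llava-style VLM works).
model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float32)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # example image
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt")

with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=20, do_sample=False)

# Keep only the newly generated tokens, sliced the same way the helper above does.
completion_ids = output[0, inputs["input_ids"].shape[1]:]
print(processor.tokenizer.decode(completion_ids, skip_special_tokens=True))

The new helper repeats this flow over zip(images, queries) and collects completion_ids.numpy() for each pair, so the test can compare per-prompt token streams against the QPC output.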

tests/transformers/models/test_image_text_to_text_models.py (64 additions, 3 deletions)
@@ -38,6 +38,7 @@
     # model_name,
     # kv_offload,
     # batch_size,
+    # full_batch_size,
     # prompt_len,
     # ctx_len,
     # img_size,
@@ -49,6 +50,7 @@
         "llava-hf/llava-1.5-7b-hf",
         True,
         1,
+        4,
         784,
         1024,
         336,
@@ -60,6 +62,7 @@
         "llava-hf/llava-1.5-7b-hf",
         False,
         1,
+        4,
         784,
         1024,
         336,
@@ -72,6 +75,7 @@
         # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
         # True,
         # 1,
+        # 4,
         # 128,
         # 3072,
         # 336,
@@ -83,6 +87,7 @@
         # "meta-llama/Llama-4-Scout-17B-16E-Instruct",
         # False,
         # 1,
+        # 4,
         # 128,
         # 3072,
         # 336,
@@ -94,6 +99,7 @@
         "google/gemma-3-4b-it",
         True,
         1,
+        4,
         128,
         3072,
         896,
@@ -105,6 +111,7 @@
         "google/gemma-3-4b-it",
         False,
         1,
+        4,
         128,
         3072,
         896,
@@ -116,6 +123,7 @@
         "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         True,
         1,
+        4,
         128,
         4096,
         1540,
@@ -127,6 +135,7 @@
         "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         False,
         1,
+        4,
         128,
         4096,
         1540,
@@ -138,6 +147,7 @@
         "Qwen/Qwen2.5-VL-3B-Instruct",
         True,
         1,
+        4,
         128,
         4096,
         1540,
@@ -149,6 +159,7 @@
         # "meta-llama/Llama-3.2-11B-Vision-Instruct",
         # True,
         # 1,
+        # 4,
         # 32,
         # 512,
         # 560,
@@ -256,6 +267,7 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     query: str,
     prompt_len: int,
     ctx_len: int,
+    full_batch_size: int,
     max_gen_len: int = 20,
     batch_size: int = 1,
     n_layer: int = 1,
@@ -341,8 +353,56 @@ def check_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     output = qeff_model.generate(inputs=inputs, generation_len=NEW_GENERATION_TOKENS, streamer=streamer)
     qpc_tokens = output.generated_ids[:, :-1]
     assert (pytorch_hf_tokens == qpc_tokens).all(), "Tokens don't match for pytorch HF output and QPC output"
-    return
 
+    # testing for CB models
+    if not kv_offload:  # CB not yet enabled for Single QPC
+        return
+    images = [image] * full_batch_size
+    queries = [query] * full_batch_size
+
+    streamer = TextStreamer(processor.tokenizer)
+    pytorch_hf_tokens = api_runner.run_vlm_hf_model_on_pytorch_CB(model_hf, images, queries)
+
+    qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+        model_config["model_name"],
+        kv_offload=kv_offload,
+        config=config,
+        continuous_batching=True,
+    )
+
+    qeff_model.export()
+
+    if not get_available_device_id():
+        pytest.skip("No available devices to run model on Cloud AI 100")
+
+    qeff_model.compile(
+        img_size=model_config["img_size"],
+        num_cores=16,
+        num_devices=num_devices,
+        prefill_seq_len=prompt_len,
+        ctx_len=ctx_len,
+        batch_size=batch_size,
+        full_batch_size=full_batch_size,
+        mxfp6_matmul=True,
+        enable_qnn=enable_qnn,
+        qnn_config=qnn_config,
+    )
+
+    print("QPC Outputs (QAIC):")
+    exec_info = qeff_model.generate(
+        tokenizer=processor.tokenizer,
+        processor=processor,
+        images=[img_url] * full_batch_size,
+        prompts=queries,
+        generation_len=max_gen_len,
+    )
+
+    qpc_tokens = exec_info.generated_ids[:, :max_gen_len]
+
+    for i in range(full_batch_size):
+        assert (pytorch_hf_tokens[i] == qpc_tokens[i]).all(), f"Tokens don't match for prompt {i} between HF and QPC output"
+
+    return
 
 def check_molmo_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
     model_name: str,
@@ -527,10 +587,10 @@ def check_intern_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
 @pytest.mark.on_qaic
 @pytest.mark.multimodal
 @pytest.mark.parametrize(
-    "model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config
+    "model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer", test_models_config
 )
 def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
-    model_name, kv_offload, batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer
+    model_name, kv_offload, batch_size, full_batch_size, prompt_len, ctx_len, img_size, img_url, query, n_layer
 ):
     """
     Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, without continuous batching.
@@ -547,6 +607,7 @@ def test_image_text_to_text_pytorch_vs_kv_vs_ort_vs_ai100(
         query=query,
         n_layer=n_layer,
         batch_size=batch_size,
+        full_batch_size=full_batch_size,
         kv_offload=kv_offload,
     )
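With the updated parametrize string, every entry in test_models_config now carries full_batch_size in the fourth position. A hypothetical entry mirroring the llava row from the hunks above (the trailing fields are elided in this diff view, so placeholders stand in for them):

# Hypothetical test_models_config entry after this commit; the field order
# matches the updated parametrize string. Placeholder values are marked.
llava_entry = (
    "llava-hf/llava-1.5-7b-hf",  # model_name
    True,                        # kv_offload (the CB branch only runs on this dual-QPC path)
    1,                           # batch_size
    4,                           # full_batch_size (new field added by this commit)
    784,                         # prompt_len
    1024,                        # ctx_len
    336,                         # img_size
    "<img_url>",                 # img_url (not shown in this diff view)
    "<query>",                   # query (not shown in this diff view)
    1,                           # n_layer (assumed; not shown in this diff view)
)

Entries with kv_offload=False still exercise the original single-batch checks and return before the CB section, matching the "CB not yet enabled for Single QPC" guard.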
