
Commit f4ff803

Adding Compute-Context-Length (CCL) (#576)
The Compute-Context-Length (CCL) technique optimizes the throughput of large language models (LLMs) on Qualcomm devices when handling very large context lengths. Ahead-Of-Time (AOT) compilation on Qualcomm devices cannot predict how many tokens will actually be needed, so attention is computed over the full compiled context length, which causes significant throughput drops during both the prefill and decode phases. To address this, we introduce Compute Context Length (CCL), an additional ONNX input variable that enables dynamic context-length specialization. By generating tokens against smaller, more manageable context lengths (CCL), we reduce memory reads and attention computations, thereby improving throughput.

---------

Signed-off-by: Vahid Janfaza <[email protected]>
1 parent c788f17 commit f4ff803
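To illustrate the idea in the description above (the values below are invented for illustration and are not from this commit): with a compiled context length of 8192 and decode CCL buckets [1024, 2048, 4096, 8192], a token at position p only needs attention over the smallest bucket that still covers p, rather than over all 8192 positions.

buckets = [1024, 2048, 4096, 8192]   # hypothetical comp-ctx-lengths-decode values

def active_ccl(position):
    # Smallest bucket that still covers this position; fall back to the full context.
    return next((b for b in buckets if position < b), buckets[-1])

print(active_ccl(300), active_ccl(1500), active_ccl(5000))   # 1024 2048 8192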

Showing 56 changed files with 3,189 additions and 304 deletions.


QEfficient/cloud/infer.py

Lines changed: 12 additions & 0 deletions
@@ -340,6 +340,18 @@ def main(
         "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp-ctx-lengths-prefill",
+        type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
+        default=[512],
+        help="Define ccl list in csv format (e.g., --comp-ctx-lengths 512,1024,2048).",
+    )
+    parser.add_argument(
+        "--comp-ctx-lengths-decode",
+        type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
+        default=[2048],
+        help="Define ccl list in csv format (e.g., --comp-ctx-lengths 512,1024,2048).",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 11 additions & 5 deletions
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)
 
@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
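To make the new comp_ctx_len argument concrete, here is a small, hypothetical eager-mode sketch of the gather; the shapes and values are invented for illustration and are not taken from the commit (the exported ONNX path uses the CtxGather script above instead):

import torch

from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc

# Assumed KV-cache layout: (batch, heads, ctx_len, head_dim)
cache = torch.randn(1, 8, 4096, 64)
comp_ctx_len = 1024  # current CCL bucket, smaller than the full 4096 context
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(1, 8, comp_ctx_len)

# Only the first comp_ctx_len cache rows are gathered, so the attention that
# consumes this output runs over 1024 positions instead of 4096.
gathered = CtxGatherFunc.apply(cache, ctx_indices, comp_ctx_len)
print(gathered.shape)  # torch.Size([1, 8, 1024, 64])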

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 12 additions & 6 deletions
@@ -97,16 +97,20 @@ def symbolic(
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context-length (CCL) instead of the full context length, so the gather and the later attention computations both run over CCL.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])
 
     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
     one = ops.Constant(value_ints=[1])
-    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    exp_shape = ops.Concat(
+        ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+    )
 
     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(
 
 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
         pass
 
     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
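For the continuous-batching variant, a similarly hypothetical eager-mode sketch (invented shapes; batch_index selects which cache row belongs to the active decode slot):

import torch

from QEfficient.customop.ctx_scatter_gather_cb import CtxGatherFuncCB

# Assumed cache for 4 concurrent sequences: (full_batch_size, heads, ctx_len, head_dim)
cache = torch.randn(4, 8, 4096, 64)
batch_index = torch.tensor([[2]])   # the decode slot currently mapped to cache row 2
comp_ctx_len = 512                  # current CCL bucket
ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1)

gathered = CtxGatherFuncCB.apply(cache, batch_index, ctx_indices, comp_ctx_len)
print(gathered.shape)  # torch.Size([1, 8, 512, 64])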

QEfficient/generation/text_generation_inference.py

Lines changed: 76 additions & 0 deletions
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths_prefill: Optional[List[int]] = None,
+    comp_ctx_lengths_decode: Optional[List[int]] = None,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
@@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+        comp_ctx_lengths_decode=comp_ctx_lengths_decode,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
@@ -430,6 +434,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
@@ -440,6 +446,8 @@ def __init__(
         activate: bool = True,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
         self.return_pdfs = return_pdfs
@@ -802,7 +810,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
             inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
 
+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
@@ -822,6 +840,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             generation_len,
         )
 
+    def initialize_ccl(self, decode_inputs):
+        self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+        max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id_initial = 0
+        ccl_id = ccl_id_initial
+        for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+            if max_position_id < self.comp_ctx_lengths_decode[i]:
+                ccl_id = i
+                break
+
+        return ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
@@ -853,6 +884,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs inputs.
         decode_inputs = self.prepare_decode_inputs()
 
+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)
 
@@ -890,6 +925,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                             batch_id_map[decode_batch_id]
                         ]
 
+                        if self.comp_ctx_lengths_decode is not None:
+                            ### Recalculate ccl_id based on position ids ###
+                            # Determine the maximum value of position_ids across all batch elements
+                            max_position_id = np.max(decode_inputs["position_ids"])
+
+                            # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+                            ccl_id_initial = 0
+                            ccl_id = ccl_id_initial
+                            for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+                                if max_position_id < self.comp_ctx_lengths_decode[i]:
+                                    ccl_id = i
+                                    break
+                            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                     else:
                         current_decode_ongoing[decode_batch_id] = False
                 else:
@@ -902,6 +951,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                     if self.include_sampler:
                         decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
 
+                    if self.comp_ctx_lengths_decode is not None:
+                        # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
+                        if (
+                            decode_inputs["position_ids"][decode_batch_id, -1]
+                            >= self.comp_ctx_lengths_decode[ccl_id] - 1
+                        ):
+                            ccl_id = min(ccl_id + 1, max_ccl_id)
+                            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                     generated_id_current_index[decode_batch_id] += 1
 
         return decode_pause_time
@@ -928,7 +986,18 @@ def run_decode(
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths_decode is not None:
+                if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)
@@ -940,6 +1009,7 @@ def run_decode(
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
            self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
             if self.include_sampler:
@@ -989,6 +1059,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
@@ -1002,6 +1074,8 @@ def __init__(
             qpc_path=qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,
@@ -1013,6 +1087,8 @@ def __init__(
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None
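The pattern added above can be summarized outside the class: pick the smallest decode bucket that covers the prefilled prompt, then step to the next bucket whenever the cache index reaches the end of the current one. The following is a self-contained sketch with assumed names and bucket values, not the class's exact code:

comp_ctx_lengths_decode = [2048, 4096, 8192]   # illustrative CCL buckets
max_ccl_id = len(comp_ctx_lengths_decode) - 1

def pick_initial_ccl(max_position_id):
    # Smallest bucket that still covers the largest prefilled position.
    for i, ccl in enumerate(comp_ctx_lengths_decode):
        if max_position_id < ccl:
            return i
    return max_ccl_id

cache_index = 3000                              # e.g. prompt length after prefill
ccl_id = pick_initial_ccl(cache_index)          # -> 1, i.e. the 4096 bucket
for _ in range(3000):                           # schematic decode loop
    if cache_index >= comp_ctx_lengths_decode[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)    # grow to the next bucket when needed
    # the session would be run here with a comp_ctx_lengths buffer of
    # size comp_ctx_lengths_decode[ccl_id]
    cache_index += 1

Judging from the np.zeros(length) buffers in the diff, the active bucket appears to be conveyed to the compiled graph through the size of the comp_ctx_lengths input rather than through its values.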

QEfficient/generation/vlm_generation.py

Lines changed: 16 additions & 0 deletions
@@ -83,6 +83,8 @@ def __init__(
         vision_qpc_path: str,
         device_id: Optional[List[int]] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         full_batch_size: Optional[int] = None,
@@ -123,6 +125,8 @@ def __init__(
             qpc_path=lang_qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,
@@ -294,6 +298,11 @@ def _execute_chunked_prefill(
         outputs = None
         chunk_image_idx = None
 
+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
             input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
             position_ids_slice = lang_inputs["position_ids"][
@@ -312,6 +321,13 @@ def _execute_chunked_prefill(
             if "cross_attention_mask" in lang_inputs:
                 chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]
 
+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
+                chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
+
             outputs = self._session.run(chunk_inputs)
 
             if "image_idx_output" in outputs:

QEfficient/peft/lora/layers.py

Lines changed: 3 additions & 3 deletions
@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
         # multilora implementation: lora_ids <batch_size, 1>
         other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
         selected_lora_a_weights = CtxGatherFuncCB.apply(
-            self.lora_a_weights, lora_ids, other_indices_a
+            self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2]
         )  # <num_loras, 1, feature, r>
         other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
         selected_lora_b_weights = CtxGatherFuncCB.apply(
-            self.lora_b_weights, lora_ids, other_indices_b
+            self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2]
         )  # <num_loras, 1, r, feature>
         other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
         selected_lora_scalings = CtxGatherFuncCB.apply(
-            self.lora_scalings, lora_ids, other_indices_s
+            self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2]
         )  # <num_loras, 1, 1, 1>
 
         selected_lora_a_weights = selected_lora_a_weights.squeeze(1)
