@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
318318 prompts_txt_file_path : Optional [str ] = None ,
319319 device_id : Optional [List [int ]] = None ,
320320 generation_len : Optional [int ] = None ,
321+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
322+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
321323 enable_debug_logs : bool = False ,
322324 stream : bool = True ,
323325 write_io_dir : Optional [str ] = None ,
@@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv(
384386 qpc_path = qpc_path ,
385387 device_id = device_id ,
386388 ctx_len = ctx_len ,
389+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
390+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
387391 enable_debug_logs = enable_debug_logs ,
388392 write_io_dir = write_io_dir ,
389393 full_batch_size = full_batch_size ,
@@ -430,6 +434,8 @@ def __init__(
430434 qpc_path : str ,
431435 full_batch_size : Optional [int ] = None ,
432436 ctx_len : Optional [int ] = None ,
437+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
438+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
433439 device_id : Optional [List [int ]] = None ,
434440 enable_debug_logs : bool = False ,
435441 write_io_dir : Optional [str ] = None ,
@@ -440,6 +446,8 @@ def __init__(
440446 activate : bool = True ,
441447 ) -> None :
442448 self ._ctx_len = ctx_len
449+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
450+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
443451 self ._write_io_dir = write_io_dir
444452 self .is_tlm = is_tlm
445453 self .return_pdfs = return_pdfs
@@ -802,7 +810,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
802810 batch_lora_ids = [self ._prompt_to_lora_id_mapping_prefill .popleft () for i in range (self .batch_size )]
803811 inputs ["lora_ids" ] = np .array (batch_lora_ids , dtype = np .int64 ).reshape (self .batch_size , 1 )
804812
813+ if self .comp_ctx_lengths_prefill is not None :
814+ self .list_of_comp_ctx_lengths_prefill = [np .zeros (length ) for length in self .comp_ctx_lengths_prefill ]
815+ prefill_ccl_id = 0
816+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
817+
805818 for i in range (num_chunks ):
819+ if self .comp_ctx_lengths_prefill is not None :
820+ if (i + 1 ) * self ._prefill_seq_len > self .comp_ctx_lengths_prefill [prefill_ccl_id ]:
821+ prefill_ccl_id = min (prefill_ccl_id + 1 , len (self .comp_ctx_lengths_prefill ) - 1 )
822+ inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_prefill [prefill_ccl_id ]
823+
806824 chunk_inputs = inputs .copy ()
807825 chunk_inputs ["input_ids" ] = inputs ["input_ids" ][
808826 :, i * self ._prefill_seq_len : (i + 1 ) * self ._prefill_seq_len
@@ -822,6 +840,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
822840 generation_len ,
823841 )
824842
843+ def initialize_ccl (self , decode_inputs ):
844+ self .list_of_comp_ctx_lengths_decode = [np .zeros (length ) for length in self .comp_ctx_lengths_decode ]
845+ max_ccl_id = len (self .comp_ctx_lengths_decode ) - 1
846+ max_position_id = np .max (decode_inputs ["position_ids" ])
847+ ccl_id_initial = 0
848+ ccl_id = ccl_id_initial
849+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths_decode )):
850+ if max_position_id < self .comp_ctx_lengths_decode [i ]:
851+ ccl_id = i
852+ break
853+
854+ return ccl_id , max_ccl_id
855+
825856 def run_continuous_batching_decode (self , prompt_queue , generation_len ):
826857 """
827858 Runs continuous batching decode for the given prompt queue and generation length.
@@ -853,6 +884,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
853884 # Prepare decode inputs inputs.
854885 decode_inputs = self .prepare_decode_inputs ()
855886
887+ if self .comp_ctx_lengths_decode is not None :
888+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
889+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
890+
856891 while prompt_queue or current_decode_ongoing .any ():
857892 outputs = self ._session .run (decode_inputs )
858893
@@ -890,6 +925,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
890925 batch_id_map [decode_batch_id ]
891926 ]
892927
928+ if self .comp_ctx_lengths_decode is not None :
929+ ###Recalculate ccl_id based on position ids###
930+ # Determine the maximum value of position_ids across all batch elements
931+ max_position_id = np .max (decode_inputs ["position_ids" ])
932+
933+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
934+ ccl_id_initial = 0
935+ ccl_id = ccl_id_initial
936+ for i in range (ccl_id_initial , len (self .comp_ctx_lengths_decode )):
937+ if max_position_id < self .comp_ctx_lengths_decode [i ]:
938+ ccl_id = i
939+ break
940+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
941+
893942 else :
894943 current_decode_ongoing [decode_batch_id ] = False
895944 else :
@@ -902,6 +951,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
902951 if self .include_sampler :
903952 decode_inputs ["last_accepted_output_tokens" ] = decode_inputs ["input_ids" ]
904953
954+ if self .comp_ctx_lengths_decode is not None :
955+ # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id
956+ if (
957+ decode_inputs ["position_ids" ][decode_batch_id , - 1 ]
958+ >= self .comp_ctx_lengths_decode [ccl_id ] - 1
959+ ):
960+ ccl_id = min (ccl_id + 1 , max_ccl_id )
961+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
962+
905963 generated_id_current_index [decode_batch_id ] += 1
906964
907965 return decode_pause_time
@@ -928,7 +986,18 @@ def run_decode(
928986 self ._session .set_buffers ({"logits" : logits_out_placeholder })
929987 finished_sequences = decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
930988 num_token = 0
989+
990+ if self .comp_ctx_lengths_decode is not None :
991+ ccl_id , max_ccl_id = self .initialize_ccl (decode_inputs )
992+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
993+
994+ cache_index = np .max (decode_inputs ["position_ids" ])
931995 for num_token in range (1 , generation_len ):
996+ if self .comp_ctx_lengths_decode is not None :
997+ if cache_index >= self .comp_ctx_lengths_decode [ccl_id ] - 1 :
998+ ccl_id = min (ccl_id + 1 , max_ccl_id )
999+ decode_inputs ["comp_ctx_lengths" ] = self .list_of_comp_ctx_lengths_decode [ccl_id ]
1000+
9321001 if streamer :
9331002 streamer .put (decode_inputs ["input_ids" ][0 ])
9341003 outputs = self ._session .run (decode_inputs )
@@ -940,6 +1009,7 @@ def run_decode(
9401009 # Prepare inputs for next iteration
9411010 decode_inputs ["input_ids" ] = self ._fetch_next_token_id (outputs )
9421011 decode_inputs ["position_ids" ][:, - 1 ] += 1
1012+ cache_index += 1
9431013 self .generated_ids [:, num_token ] = decode_inputs ["input_ids" ][:, - 1 ]
9441014 finished_sequences |= decode_inputs ["input_ids" ] == self .tokenizer .eos_token_id
9451015 if self .include_sampler :
@@ -989,6 +1059,8 @@ def __init__(
9891059 qpc_path : str ,
9901060 full_batch_size : Optional [int ] = None ,
9911061 ctx_len : Optional [int ] = None ,
1062+ comp_ctx_lengths_prefill : Optional [List [int ]] = None ,
1063+ comp_ctx_lengths_decode : Optional [List [int ]] = None ,
9921064 device_id : Optional [List [int ]] = None ,
9931065 enable_debug_logs : bool = False ,
9941066 write_io_dir : Optional [str ] = None ,
@@ -1002,6 +1074,8 @@ def __init__(
10021074 qpc_path = qpc_path ,
10031075 full_batch_size = full_batch_size ,
10041076 ctx_len = ctx_len ,
1077+ comp_ctx_lengths_prefill = comp_ctx_lengths_prefill ,
1078+ comp_ctx_lengths_decode = comp_ctx_lengths_decode ,
10051079 device_id = device_id ,
10061080 enable_debug_logs = enable_debug_logs ,
10071081 write_io_dir = write_io_dir ,
@@ -1013,6 +1087,8 @@ def __init__(
10131087 self ._full_batch_size = self ._qaic_model .full_batch_size
10141088 self ._tokenizer = self ._qaic_model .tokenizer
10151089 self ._ctx_len = ctx_len
1090+ self .comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1091+ self .comp_ctx_lengths_decode = comp_ctx_lengths_decode
10161092 self ._perf_metrics = None
10171093 self ._prompt_queue = None
10181094 self ._text_streamer = None
0 commit comments