Skip to content

Commit e542c22

Browse files
committed
bugfix: fix coredump issue when both prefixcache and mtp are enabled. (#377)
* bugfix: fix coredump issue when both prefix cache and MTP (multi-token prediction) are enabled.
* bugfix: fix coredump caused by incorrect token replacement.
1 parent bf9b671 commit e542c22

File tree

8 files changed

+47
-23
lines changed

8 files changed

+47
-23
lines changed

xllm/core/framework/model/model_args.h

file mode changed: 100755 → 100644
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ struct ModelArgs {
5454
PROPERTY(int64_t, vocab_size) = -1;
5555

5656
PROPERTY(bool, use_qk_norm) = false;
57-
57+
5858
PROPERTY(float, rms_norm_eps) = 0.0f;
5959

6060
PROPERTY(float, layer_norm_eps) = 0.0f;

xllm/core/framework/request/sequence_kv_state.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
5858
if (blocks.empty()) {
5959
return;
6060
}
61-
6261
// The number of matched blocks may be fewer than the number of blocks held by
6362
// the sequence itself. In this case, try to replace the blocks computed by
6463
// the sequence with blocks from the prefix_cache and release the computed
@@ -86,6 +85,10 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
8685
CHECK_GT(block_size, 0);
8786
num_shared_tokens =
8887
((current_total_num_tokens - 1) / block_size) * block_size;
88+
if (num_owned_shared_blocks_ > 0) {
89+
num_owned_shared_blocks_--;
90+
blocks_.pop_back();
91+
}
8992
}
9093
CHECK_LT(num_shared_tokens, current_total_num_tokens);
9194
// update the kv cache position

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
382382

383383
param.mlpLinearTransposeType = {1, -1, 1, -1};
384384

385-
param.enableSplitFuse = (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
385+
param.enableSplitFuse =
386+
(FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
386387

387388
param.moeLinearTransposeType = (layer_id_ < args.first_k_dense_replace())
388389
? std::vector<int>{-1, -1, -1, -1}
@@ -406,7 +407,7 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
406407
param.enableSwiGLUQuantForSharedExperts = false; // TODO
407408

408409
param.useQKNorm = args.use_qk_norm();
409-
if(args.use_qk_norm()){
410+
if (args.use_qk_norm()) {
410411
WEIGHT_COUNT_PER_LAYER = 70;
411412
WEIGHT_MAPPING_W8A8["self_attn.q_norm.weight"] = Q_NORM_WEIGHT;
412413
WEIGHT_MAPPING_W8A8["self_attn.k_norm.weight"] = K_NORM_WEIGHT;

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
189189
// should be in same prefill stage, so, to judge empty_kv_cache,
190190
// just use micro batch 0 here
191191
if (options_.enable_speculative_decode() && !is_spec_draft_) {
192-
if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
192+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
193193
output.sample_output.embeddings = hidden_states;
194194
} else if (concated_sampling_params.sample_idxes.defined()) {
195195
// auto sample_idxes =

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
173173
}
174174

175175
// TODO: support data parallel case
176-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
176+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
177177
return step_prefill(inputs);
178178
} else {
179179
return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
182182

183183
std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
184184
const BatchedForwardInputs& inputs) {
185-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
185+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
186186
auto output = impl_->step(inputs);
187187
auto draft_output = draft_impl_->step(inputs);
188188
return output;
@@ -230,7 +230,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
230230
if (token_offset > 0) {
231231
prefill_inputs.micro_inputs[i].input_params.mm_data = MMData(
232232
MMType::EMBEDDING,
233-
{{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
233+
{{"embedding",
234+
embeddings.narrow(0, token_start_idx, token_offset).clone()}});
234235
}
235236
if (next_tokens.defined()) {
236237
auto& token_ids = prefill_inputs.micro_inputs[i].token_ids;
@@ -293,6 +294,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
293294
void SpeculativeWorkerImpl::prepare_prefill_inputs(
294295
const BatchedForwardInputs& inputs,
295296
BatchedForwardInputs& prefill_inputs) {
297+
prefill_inputs.micro_inputs.clear();
296298
prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
297299
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
298300
auto& input = inputs.micro_inputs[i];
@@ -308,16 +310,16 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
308310
int32_t start_idx = 0;
309311
std::vector<int32_t> new_token_ids;
310312
new_token_ids.reserve(input.token_ids.numel());
311-
for (size_t i = 0; i < input_params.num_sequences; ++i) {
313+
for (size_t j = 0; j < input_params.num_sequences; ++j) {
312314
int32_t q_len = 0;
313-
q_len = input_params.q_seq_lens_vec[i];
315+
q_len = input_params.q_seq_lens_vec[j];
314316
Slice<int32_t> tokens_ids_slice_i =
315317
tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
316318
start_idx += q_len;
317319
new_token_ids.insert(new_token_ids.end(),
318320
tokens_ids_slice_i.begin(),
319321
tokens_ids_slice_i.end());
320-
new_token_ids.emplace_back(extra_token_ids[i]);
322+
new_token_ids.emplace_back(extra_token_ids[j]);
321323
}
322324
prefill_input.token_ids =
323325
torch::tensor(new_token_ids, prefill_input.positions.options());
@@ -359,7 +361,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
359361
// final step
360362
prepare_validate_inputs(inputs, validate_inputs, true);
361363
} else {
362-
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
364+
if (i == 0) {
365+
prepare_draft_inputs(inputs, next_step_input, 1, device_);
366+
} else {
367+
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
368+
}
363369
}
364370
draft_outputs.push_back(std::move(future).get().value());
365371
// update input of next step
@@ -368,8 +374,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
368374
auto last_output = draft_outputs.back().sample_output;
369375
auto start_idx = 0;
370376
auto token_start_idx = 0;
371-
for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) {
372-
auto& draft_input = draft_inputs.micro_inputs[i];
377+
for (auto j = 0; j < draft_inputs.micro_inputs.size(); ++j) {
378+
auto& draft_input = draft_inputs.micro_inputs[j];
373379
auto offset = draft_input.input_params.num_sequences;
374380
auto token_offset = draft_input.token_ids.size(0);
375381
draft_input.token_ids = safe_to(
@@ -379,6 +385,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
379385
MMType::EMBEDDING,
380386
{{"embedding",
381387
last_output.embeddings.narrow(0, token_start_idx, token_offset)
388+
.clone()
382389
.to(device_)}});
383390
}
384391
start_idx += offset;
@@ -394,9 +401,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
394401
auto next_tokens =
395402
safe_to(draft_output.sample_output.next_tokens, torch::kInt);
396403
int32_t start_idx = 0;
397-
for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) {
398-
int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences;
399-
auto& validate_input = validate_inputs.micro_inputs[i];
404+
for (auto j = 0; j < validate_inputs.micro_inputs.size(); ++j) {
405+
int32_t offset =
406+
validate_inputs.micro_inputs[j].input_params.num_sequences /
407+
(options_.num_speculative_tokens() + 1);
408+
auto& validate_input = validate_inputs.micro_inputs[j];
400409
auto& token_ids = validate_input.token_ids;
401410
auto mask = (token_ids == -1 * (i + 1));
402411
token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
@@ -447,9 +456,10 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
447456
const int64_t offset,
448457
const torch::Device device) {
449458
// prepare input for MTP in decoding phase (Like Eagle).
459+
draft_inputs.micro_inputs.clear();
450460
draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
451-
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
452-
auto& input = inputs.micro_inputs[i];
461+
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
462+
auto& input = inputs.micro_inputs[idx];
453463
ForwardInput draft_input = input.to(device, dtype_);
454464

455465
auto& input_params = draft_input.input_params;
@@ -504,8 +514,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
504514
BatchedForwardInputs& validate_inputs,
505515
bool enable_schedule_overlap) {
506516
validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
507-
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
508-
auto& input = inputs.micro_inputs[i];
517+
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
518+
auto& input = inputs.micro_inputs[idx];
509519

510520
ForwardInput validate_input = input.to(device_, dtype_);
511521
auto& input_params = validate_input.input_params;
@@ -823,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
823833
void SpeculativeWorkerImpl::prepare_work_before_execute(
824834
const BatchedForwardInputs& inputs,
825835
BatchedForwardInputs& processed_inputs) {
826-
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
836+
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
827837
WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
828838
} else {
829839
if (enable_schedule_overlap()) {

xllm/core/runtime/worker_impl.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,4 +780,13 @@ int64_t WorkerImpl::get_active_activation_memory() {
780780
.active_activation_memory;
781781
}
782782

783+
// Returns true if any sequence in the batch is still in the prefill phase,
// i.e. has more than one query token in this step (q_len > 1); decode steps
// contribute exactly one query token per sequence. Replaces the previous
// `q_seq_lens_vec[0] > 1` checks at call sites, which inspected only the
// first sequence (and crashed/behaved wrongly on mixed batches when prefix
// cache and MTP speculative decoding are both enabled).
bool WorkerImpl::check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
784+
for (auto q_len : q_seq_lens_vec) {
785+
if (q_len > 1) {
786+
// At least one sequence has multiple pending query tokens -> prefill.
return true;
787+
}
788+
}
789+
// Every sequence has q_len <= 1 -> pure decode batch.
return false;
790+
}
791+
783792
} // namespace xllm

xllm/core/runtime/worker_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class WorkerImpl {
166166

167167
torch::ScalarType dtype() const { return dtype_; }
168168

169+
bool check_is_prefill(const std::vector<int>& q_seq_lens_vec);
170+
169171
int32_t hidden_size() const {
170172
return context_.get_model_args().hidden_size();
171173
}

xllm/core/scheduler/continuous_scheduler.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ ContinuousScheduler::ContinuousScheduler(Engine* engine, const Options& options)
9393
} else {
9494
min_speculative_tokens_required_ = options_.num_speculative_tokens();
9595
}
96-
9796
}
9897

9998
ContinuousScheduler::~ContinuousScheduler() { running_requests_.clear(); }

0 commit comments

Comments (0)