
Commit ba5be35

feat: support qwen2 and rename qwen3 layers. (#364)
1 parent ce7a10a · commit ba5be35

30 files changed: +446 −649 lines

xllm/core/framework/parallel_state/parallel_args.h

Lines changed: 1 addition & 1 deletion
@@ -99,10 +99,10 @@ struct ParallelArgs {
   // ep size
   PROPERTY(int32_t, ep_size) = 1;
 
-#if defined(USE_NPU)
   // atb hccl mapping json data
   PROPERTY(nlohmann::json, mapping_data);
 
+#if defined(USE_NPU)
   // atb hccl mapping
   PROPERTY(atb_speed::base::Mapping, mapping);
 

xllm/core/framework/state_dict/utils.cpp

Lines changed: 20 additions & 4 deletions
@@ -81,14 +81,21 @@ void load_fused_weight(const StateDict& state_dict,
                        int32_t world_size,
                        std::vector<torch::Tensor>& accumulated_tensors,
                        torch::Tensor& weight,
-                       bool& weight_is_loaded) {
+                       bool& weight_is_loaded,
+                       int32_t num_kv_head_replicas) {
   // return if the weight is already loaded
   if (weight_is_loaded) {
     return;
   }
 
-  weight_is_loaded = load_tensor_list(
-      state_dict, prefixes, name, dim, rank, world_size, accumulated_tensors);
+  weight_is_loaded = load_tensor_list(state_dict,
+                                      prefixes,
+                                      name,
+                                      dim,
+                                      rank,
+                                      world_size,
+                                      accumulated_tensors,
+                                      num_kv_head_replicas);
 
   if (weight_is_loaded) {
     const auto merged_weight = torch::cat(accumulated_tensors, /*dim=*/dim);
@@ -106,7 +113,8 @@ bool load_tensor_list(const StateDict& state_dict,
                       int64_t dim,
                       int32_t rank,
                       int32_t world_size,
-                      std::vector<torch::Tensor>& tensors) {
+                      std::vector<torch::Tensor>& tensors,
+                      int32_t num_kv_head_replicas) {
   // resize the accumulated weight list if needed
   if (tensors.size() < prefixes.size()) {
     tensors.resize(prefixes.size());
@@ -118,6 +126,14 @@ bool load_tensor_list(const StateDict& state_dict,
       continue;
     }
 
+    // When the number of key/value heads is smaller than the number of query
+    // heads (e.g., multi-query/grouped-query attention), the key/value head may
+    // be replicated while the query heads are partitioned.
+    if (i == 1 && num_kv_head_replicas > 1) {
+      rank = rank / num_kv_head_replicas;
+      world_size = world_size / num_kv_head_replicas;
+    }
+
     const std::string tensor_name = prefixes[i] + name;
     torch::Tensor tensor;
     if (dim < 0) {
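
Note: a minimal standalone sketch (not part of this commit) of how the rank remapping above behaves under grouped-query attention, assuming 8 tensor-parallel ranks and 2 total KV heads, so num_kv_head_replicas = 8 / 2 = 4:

```cpp
// Sketch (NOT from the commit): how the remapping collapses tensor-parallel
// ranks onto replicated KV shards. Assumed config: 8 TP ranks, 2 KV heads.
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t world_size = 8;
  const int32_t num_kv_head_replicas = 4;  // world_size / total_kv_heads
  for (int32_t rank = 0; rank < world_size; ++rank) {
    // Same arithmetic as the i == 1 branch in load_tensor_list above.
    const int32_t kv_rank = rank / num_kv_head_replicas;
    const int32_t kv_world_size = world_size / num_kv_head_replicas;
    std::printf("rank %d -> kv shard %d of %d\n", rank, kv_rank, kv_world_size);
  }
  return 0;
}
```

With these numbers, ranks 0–3 load KV shard 0 and ranks 4–7 load shard 1: each KV shard is replicated on four ranks while the query projection stays partitioned eight ways.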

xllm/core/framework/state_dict/utils.h

Lines changed: 16 additions & 2 deletions
@@ -56,15 +56,17 @@ void load_fused_weight(const StateDict& state_dict,
                        int32_t world_size,
                        std::vector<torch::Tensor>& accumulated_tensors,
                        torch::Tensor& weight,
-                       bool& weight_is_loaded);
+                       bool& weight_is_loaded,
+                       int32_t num_kv_head_replicas = 1);
 
 bool load_tensor_list(const StateDict& state_dict,
                       const std::vector<std::string>& prefixes,
                       const std::string& name,
                       int64_t dim,
                       int32_t rank,
                       int32_t world_size,
-                      std::vector<torch::Tensor>& accumulated_tensors);
+                      std::vector<torch::Tensor>& accumulated_tensors,
+                      int32_t num_kv_head_replicas = 1);
 
 void load_moe_weight(const StateDict& state_dict,
                      const std::string& sub_prefix,
@@ -114,6 +116,18 @@ void load_moe_fused_weight(const StateDict& state_dict,
                             name##_,                     \
                             name##_is_loaded_);
 
+#define LOAD_QKV_WEIGHT(name, dim, num_kv_head_replicas) \
+  weight::load_fused_weight(state_dict,                  \
+                            prefixes,                    \
+                            #name,                       \
+                            dim,                         \
+                            rank,                        \
+                            world_size,                  \
+                            name##_list_,                \
+                            name##_,                     \
+                            name##_is_loaded_,           \
+                            num_kv_head_replicas);
+
 #define LOAD_SHARDED_WEIGHT(name, dim) \
   weight::load_sharded_weight(         \
       state_dict, #name, dim, rank, world_size, name##_, name##_is_loaded_);
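
Note: for reference, an illustrative hand expansion (not compiler output) of the new macro at the call site used later in linear.cpp, LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_), assuming state_dict, prefixes, rank, and world_size are in scope:

```cpp
// #name stringizes the token; name## pastes the member-variable suffixes.
weight::load_fused_weight(state_dict,
                          prefixes,
                          "weight",           // #name
                          0,                  // dim to concatenate along
                          rank,
                          world_size,
                          weight_list_,       // name##_list_
                          weight_,            // name##_
                          weight_is_loaded_,  // name##_is_loaded_
                          num_kv_head_replicas_);
```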

xllm/core/layers/common/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
@@ -6,31 +6,31 @@ cc_library(
   HDRS
     flashinfer_workspace.h
     deepseek_v2_attention.h
-    qwen3_attention.h
+    qwen2_attention.h
     attention.h
     fuse_norm.h
     rotary_embedding.h
     fused_moe.h
     dense_mlp.h
-    qwen3_decoder_layer.h
+    qwen2_decoder_layer.h
     qwen3_moe_decoder_layer.h
-    linear_impl.h
     linear.h
     word_embedding_impl.h
     layer_utils.h
     indexer.h
   SRCS
     flashinfer_workspace.cpp
     deepseek_v2_attention.cpp
-    qwen3_attention.cpp
+    qwen2_attention.cpp
     attention.cpp
     fuse_norm.cpp
     rotary_embedding.cpp
     fused_moe.cpp
     dense_mlp.cpp
-    qwen3_decoder_layer.cpp
+    qwen2_decoder_layer.cpp
     qwen3_moe_decoder_layer.cpp
-    linear_impl.cpp
+    linear.cpp
+    word_embedding_impl.cpp
     layer_utils.cpp
     indexer.cpp
   DEPS

xllm/core/layers/common/fuse_norm.cpp

Lines changed: 1 addition & 6 deletions
@@ -59,12 +59,7 @@ torch::Tensor FusedRMSNormImpl::forward_output(torch::Tensor& input,
 }
 
 void FusedRMSNormImpl::load_state_dict(const StateDict& state_dict) {
-  const auto weight = state_dict.get_tensor("weight");
-  if (weight.defined()) {
-    CHECK_EQ(weight_.sizes(), weight.sizes())
-        << "weight size mismatch for " << name();
-    weight_.copy_(weight);
-  }
+  LOAD_WEIGHT(weight);
 }
 
 }  // namespace layer
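
Note: LOAD_WEIGHT is defined in state_dict/utils.h, which this diff does not show. Based on the code it replaces and the name##_ / name##_is_loaded_ convention of the other loader macros, a plausible (assumed, unverified) expansion of LOAD_WEIGHT(weight) is:

```cpp
// ASSUMED expansion, inferred from the deleted code above; not from the diff.
const auto tensor = state_dict.get_tensor("weight");
if (tensor.defined()) {
  CHECK_EQ(weight_.sizes(), tensor.sizes()) << "weight size mismatch";
  weight_.copy_(tensor);
  weight_is_loaded_ = true;  // assumed bookkeeping; not confirmed by the diff
}
```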

xllm/core/layers/common/indexer.cpp

Lines changed: 0 additions & 8 deletions
@@ -323,13 +323,5 @@ void IndexerImpl::load_state_dict(const StateDict& state_dict) {
                                state_dict.get_dict_with_prefix("weights_proj."));
 }
 
-// whether the weight is loaded
-void IndexerImpl::verify_loaded_weights(const std::string& prefix) const {
-  // Verify that all linear layers have loaded their weights
-  wq_b_->verify_loaded_weights(prefix + "wq_b.");
-  wk_->verify_loaded_weights(prefix + "wk.");
-  weights_proj_->verify_loaded_weights(prefix + "weights_proj.");
-}
-
 }  // namespace layer
 }  // namespace xllm

xllm/core/layers/common/indexer.h

Lines changed: 0 additions & 3 deletions
@@ -61,9 +61,6 @@ class IndexerImpl : public torch::nn::Module {
   // load the weight from the checkpoint
   void load_state_dict(const StateDict& state_dict);
 
-  // whether the weight is loaded
-  void verify_loaded_weights(const std::string& prefix = "") const;
-
  private:
   int64_t dim_;
   int64_t n_heads_;

xllm/core/layers/common/layer_utils.cpp

Lines changed: 1 addition & 13 deletions
@@ -20,18 +20,6 @@ limitations under the License.
 namespace xllm {
 namespace layer {
 
-bool is_dummy_run(const ModelInputParams& input_params,
-                  const ParallelArgs& parallel_args) {
-  int64_t dp_rank = 0;
-  if (parallel_args.dp_size() > 1) {
-    dp_rank = parallel_args.dp_local_process_group_->rank();
-  }
-  if (input_params.dp_global_token_nums.size() <= 1) {
-    return input_params.q_max_seq_len == 0;
-  }
-  return input_params.dp_global_token_nums[dp_rank] == 0;
-}
-
 void update_dummy_run_input(int64_t dp_rank,
                             torch::Tensor& positions,
                             ModelInputParams& input_params) {
@@ -48,4 +36,4 @@ void update_dummy_run_input(int64_t dp_rank,
 }
 
 }  // namespace layer
-}  // namespace xllm
+}  // namespace xllm

xllm/core/layers/common/layer_utils.h

Lines changed: 1 addition & 4 deletions
@@ -20,12 +20,9 @@ limitations under the License.
 namespace xllm {
 namespace layer {
 
-bool is_dummy_run(const ModelInputParams& input_params,
-                  const ParallelArgs& parallel_args);
-
 void update_dummy_run_input(int64_t dp_rank,
                             torch::Tensor& positions,
                             ModelInputParams& input_params);
 
 }  // namespace layer
-}  // namespace xllm
+}  // namespace xllm

xllm/core/layers/common/linear_impl.cpp renamed to xllm/core/layers/common/linear.cpp

Lines changed: 19 additions & 59 deletions
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "linear_impl.h"
+#include "linear.h"
 
 #include <glog/logging.h>
 #include <torch/torch.h>
@@ -82,9 +82,8 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(
 
 torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
   input = input.to(device_);
-  auto bias = (bias_.defined() && rank_ == 0)
-                  ? std::optional<torch::Tensor>(bias_)
-                  : std::nullopt;
+  auto bias =
+      bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
 
   torch::Tensor output;
 
@@ -148,8 +147,8 @@ torch::Tensor ColumnParallelLinearImpl::forward(torch::Tensor input) {
 
 // load the weight from the checkpoint
 void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
 
   // load and merge the weights on dim 0
   // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
@@ -172,8 +171,8 @@ void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
 void ColumnParallelLinearImpl::load_state_dict(
     const StateDict& state_dict,
     const std::vector<std::string>& prefixes) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
 
   // load and merge the weights on dim 0
   // If quant_args_ indicates SmoothQuant, load qweight
@@ -192,7 +191,6 @@ void ColumnParallelLinearImpl::load_state_dict(
         break;
       }
     }
-
     LOAD_FUSED_WEIGHT(qweight, 0);
     LOAD_FUSED_WEIGHT(per_channel_scale, 0);
   } else {
@@ -223,36 +221,32 @@ QKVParallelLinearImpl::QKVParallelLinearImpl(
       parallel_args_(parallel_args),
       options_(options),
       device_(options.device()) {
-  const int32_t QKV_CNT = 3;
   rank_ = parallel_args_.tp_group_->rank();
   world_size_ = parallel_args_.tp_group_->world_size();
   const int64_t out_features_per_partition =
       (num_heads + 2 * num_kv_heads) * head_size;
   // Note: torch.nn.functional.linear performs XA^T + b and as a result
   // we allocate the transpose.
-  qkv_weight_ = register_parameter(
+  weight_ = register_parameter(
       "weight",
       torch::empty({out_features_per_partition, hidden_size}, options),
       /*requires_grad=*/false);
-  qkv_weight_list_.resize(QKV_CNT);
 
   if (bias) {
-    qkv_bias_ =
+    bias_ =
         register_parameter("bias",
                            torch::empty({out_features_per_partition}, options),
                            /*requires_grad=*/false);
-    qkv_bias_list_.resize(QKV_CNT);
   }
 }
 
 torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
   input = input.to(device_);
-  auto bias = (qkv_bias_.defined() && rank_ == 0)
-                  ? std::optional<torch::Tensor>(qkv_bias_)
-                  : std::nullopt;
+  auto bias =
+      bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
   xllm::kernel::MatmulParams matmul_params;
   matmul_params.a = input;
-  matmul_params.b = qkv_weight_;
+  matmul_params.b = weight_;
   matmul_params.bias = bias;
 
   auto output = xllm::kernel::matmul(matmul_params);
@@ -262,46 +256,13 @@ torch::Tensor QKVParallelLinearImpl::forward(torch::Tensor input) {
   return output;
 }
 
-bool QKVParallelLinearImpl::load_qkv_weight(const StateDict& state_dict,
-                                            int32_t index) {
-  if (qkv_weight_list_[index].defined() || state_dict.size() == 0) {
-    return false;
-  }
-  DEFINE_WEIGHT(weight);
-  int64_t out_feature = num_heads_ * head_size_;
-  int64_t rank = rank_;
-  int64_t world_size = world_size_;
-  if (index > 0) {
-    rank = rank_ / num_kv_head_replicas_;
-    world_size = world_size_ / num_kv_head_replicas_;
-    out_feature = num_kv_heads_ * head_size_;
-  }
-  weight_ = torch::empty({out_feature, hidden_size_}, options_);
-  LOAD_SHARDED_WEIGHT(weight, 0);
-  if (weight_is_loaded_) {
-    qkv_weight_list_[index] = weight_.clone();
-  }
-  return weight_is_loaded_;
-}
-
 void QKVParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
   std::vector<std::string> prefixes = {"q_proj.", "k_proj.", "v_proj."};
-  if (!qkv_weight_is_loaded_) {
-    bool all_loaded = true;
-    for (size_t i = 0; i < prefixes.size(); ++i) {
-      all_loaded =
-          all_loaded &&
-          load_qkv_weight(state_dict.get_dict_with_prefix(prefixes[i]), i);
-    }
-    if (all_loaded) {
-      const auto merged_weight = torch::cat(qkv_weight_list_, /*dim=*/0);
-      CHECK_EQ(qkv_weight_.sizes(), merged_weight.sizes())
-          << "weight size mismatch";
-      qkv_weight_.copy_(merged_weight);
-      // release the memory for weight_list
-      qkv_weight_list_.clear();
-      qkv_weight_is_loaded_ = true;
-    }
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
+  LOAD_QKV_WEIGHT(weight, 0, num_kv_head_replicas_);
+  if (bias_.defined()) {
+    LOAD_QKV_WEIGHT(bias, 0, num_kv_head_replicas_);
   }
 }
 
@@ -424,8 +385,8 @@ torch::Tensor RowParallelLinearImpl::forward(torch::Tensor input) {
 
 // load the weight from the checkpoint
 void RowParallelLinearImpl::load_state_dict(const StateDict& state_dict) {
-  const auto rank = rank_;
-  const auto world_size = world_size_;
+  const int64_t rank = rank_;
+  const int64_t world_size = world_size_;
 
   // If quant_args_ indicates SmoothQuant, load qweight; otherwise, load
   // normal weight.
@@ -462,7 +423,6 @@ ReplicatedLinearImpl::ReplicatedLinearImpl(
 }
 
 torch::Tensor ReplicatedLinearImpl::forward(torch::Tensor input) {
-  namespace F = torch::nn::functional;
   auto bias =
       bias_.defined() ? std::optional<torch::Tensor>(bias_) : std::nullopt;
   xllm::kernel::MatmulParams matmul_params;
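
Note: the constructor above allocates (num_heads + 2 * num_kv_heads) * head_size rows per partition, and LOAD_QKV_WEIGHT fills them by concatenating the q/k/v shards along dim 0. A standalone sketch with made-up dimensions (hypothetical config, not from the commit) shows the resulting per-rank shapes; note also that forward() now adds the bias whenever it is defined rather than only on rank 0, presumably because each rank owns its own slice of the dim-0-sharded bias.

```cpp
// Hypothetical GQA config (NOT from the commit) to illustrate the per-rank
// fused QKV shapes that QKVParallelLinearImpl allocates.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t hidden_size = 4096;
  const int64_t head_size = 128;
  const int64_t total_heads = 32;    // query heads in the checkpoint
  const int64_t total_kv_heads = 8;  // KV heads in the checkpoint
  const int64_t world_size = 16;     // tensor-parallel degree

  // Per-rank head counts; once ranks outnumber KV heads, each rank keeps one
  // KV head and the heads are replicated across rank groups.
  const int64_t num_heads = total_heads / world_size;         // 2
  const int64_t num_kv_heads =
      std::max<int64_t>(total_kv_heads / world_size, 1);      // 1
  const int64_t num_kv_head_replicas =
      std::max<int64_t>(world_size / total_kv_heads, 1);      // 2

  // Mirrors the constructor: q, k, and v rows stacked along dim 0.
  const int64_t out_features_per_partition =
      (num_heads + 2 * num_kv_heads) * head_size;             // 512

  std::printf("fused qkv weight per rank: [%lld, %lld], kv replicas: %lld\n",
              static_cast<long long>(out_features_per_partition),
              static_cast<long long>(hidden_size),
              static_cast<long long>(num_kv_head_replicas));
  return 0;
}
```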
